jinnax.data

  1#Functions to process data for training
  2import pandas
  3import jax
  4import jax.numpy as jnp
  5from jax import random
  6import numpy as np
  7import random
  8import sys
  9from PIL import Image
 10from IPython.display import display
 11__docformat__ = "numpy"
 12
 13#Generate d-dimensional data for PINN training
 14def generate_PINNdata(u,xl,xu,tl = None,tu = None,Ns = None,Nts = None,Nb = None,Ntb = None,N0 = None,Nc = None,Ntc = None,train = True,d = 1,p = 1,poss = 'grid',posts = 'grid',pos0 = 'grid',posb = 'grid',postb = 'grid',posc = 'grid',postc = 'grid',sigmas = 0,sigmab = 0,sigma0 = 0):
 15    """
 16    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
 17    ----------
 18
 19    Parameters
 20    ----------
 21    u : function
 22
 23        The function u(x,t) solution of the PDE
 24
 25    xl : float
 26
 27        Lower bound of each x coordinate
 28
 29    xu : float
 30
 31        Upper bound of each x coordinate
 32
 33    tl : float
 34
 35        Lower bound of the time interval
 36
 37    tu : float
 38
 39        Upper bound of the time interval
 40
 41    Ns : int, None
 42
 43        Number of points along each x coordinate for sensor data. None for not generating sensor data
 44
 45    Nts : int, None
 46
 47        Number of points along the time axis for sensor data. None for not generating sensor data
 48
 49    Nb : int, None
 50
 51        Number of points along each x coordinate for boundary data. None for not generating boundary data
 52
 53    Ntb : int, None
 54
 55        Number of points along the time axis for boundary data. None for not generating boundary data
 56
 57    N0 : int, None
 58
 59        Number of points along each x coordinate for initial data. None for not generating initial data
 60
 61    Nc : int, None
 62
 63        Number of points along each x coordinate for collocation points. None for not generating collocation points
 64
 65    Ntc : int, None
 66
 67        Number of points along the time axis for collocation points. None for not generating collocation points
 68
 69    train : logical
 70
 71        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
 72
 73    d : int
 74
 75        Domain dimension. Default 1
 76
 77    p : int
 78
 79        Output dimension. Default 1
 80
 81    poss : str
 82
 83        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 84
 85    posts : str
 86
 87        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 88
 89    posb : int
 90
 91        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 92
 93    postb : int
 94
 95        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 96
 97    pos0 : int
 98
 99        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
100
101    posc : str
102
103        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
104
105    postc : str
106
107        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
108
109    sigmas : str
110
111        Standard deviation of the Gaussian noise of sensor data. Default 0
112
113    sigmab : str
114
115        Standard deviation of the Gaussian noise of boundary data. Default 0
116
117    sigma0 : str
118
119        Standard deviation of the Gaussian noise of initial data. Default 0
120
121    Returns
122    -------
123
124    dict-like object with generated data
125
126    """
127
128    #Repeat x limits
129    if isinstance(xl,int) or isinstance(xl,float):
130        xl = [xl for i in range(d)]
131    if isinstance(xu,int) or isinstance(xu,float):
132        xu = [xu for i in range(d)]
133
134    #Sensor data
135    if train:
136        if Ns is not None or Nts is not None:
137            if poss == 'grid':
138                #Create the grid
139                x_sensor = [jnp.linspace(xl[i],xu[i],Ns + 2)[1:-1] for i in range(d)]
140                x_sensor = jnp.meshgrid(*x_sensor, indexing='ij')
141                x_sensor = jnp.stack(x_sensor, axis=-1).reshape((-1, d))
142            else:
143                #Sample Ns^d points for the first coordinate
144                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
145                for i in range(d-1):
146                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
147                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
148            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
149            if Nts is not None:
150                if posts == 'grid':
151                    #Create the Nt grid of (tl,tu]
152                    t_sensor = jnp.linspace(tl,tu,Nts + 1)[1:]
153                else:
154                    #Sample Nt points from (tl,tu)
155                    t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
156                #Product of x and t
157                xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
158                #Calculate u at each point
159                u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
160                u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
161            else:
162                xt_sensor = x_sensor
163                #Calculate u at each point
164                u_sensor = u(xt_sensor)
165        else:
166            #Return None if sensor data should not be generated
167            xt_sensor = None
168            u_sensor = None
169
170        #Set collocation points (always in an interior grid)
171        if Ntc is not None or Nc is not None:
172            if Ntc is not None:
173                if postc == 'grid':
174                    #Create the Ntc grid of (tl,tu]
175                    t_collocation = jnp.linspace(tl,tu,Ntc + 1)[1:]
176                else:
177                    #Sample Ntc points from (tl,tu)
178                    t_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntc,))
179            if posc == 'grid':
180                #Create the grid
181                x_collocation = [jnp.linspace(xl[i],xu[i],Nc + 2)[1:-1] for i in range(d)]
182                x_collocation = jnp.meshgrid(*x_collocation, indexing='ij')
183                x_collocation = jnp.stack(x_collocation, axis=-1).reshape((-1, d))
184                if Ntc is not None:
185                    #Product of x and t
186                    xt_collocation = jnp.array([x.tolist() + [t.tolist()] for t in t_collocation for x in x_collocation],dtype = jnp.float32)
187                else:
188                    xt_collocation = x_collocation
189            else:
190                #Sample Nc^d points for the first coordinate
191                x_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nc ** d,1))
192                for i in range(d-1):
193                    #Sample Nc^d points for the i-th coordinate and append collumn-wise
194                    x_collocation =  jnp.append(x_collocation,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Nc ** d,1)),1)
195                if Ntc is not None:
196                    #Product of x and t
197                    xt_collocation = jnp.array([x.tolist() + [t.tolist()] for t in t_collocation for x in x_collocation],dtype = jnp.float32)
198                else:
199                    xt_collocation = x_collocation
200        else:
201            #Return None if collocation data should not be generated
202            xt_collocation = None
203
204        #Boundary data
205        if Ntb is not None or Nb is not None:
206            if Ntb is not None:
207                if postb == 'grid':
208                    #Create the Ntb grid of (tl,tu]
209                    t_boundary = jnp.linspace(tl,tu,Ntb + 1)[1:]
210                else:
211                    #Sample Ntb points from (tl,tu)
212                    t_boundary = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntb,))
213
214            #An array in which each line represents an edge of the n-cube
215            pre_grid = [[xl[0]],[xu[0]],[jnp.inf]]
216            for i in range(d - 1):
217                pre_grid = [x1 + [x2] for x1 in pre_grid for x2 in [xl[i + 1],xu[i + 1],jnp.inf]]
218            #Exclude last row
219            pre_grid = pre_grid[:-1]
220            #Create array with vertex (xl,...,xl)
221            x_boundary = jnp.array(pre_grid[0],dtype = jnp.float32).reshape((1,d))
222            if posb == 'grid':
223                #Create a grid over each edge of the n-cube
224                for i in range(len(pre_grid) - 1):
225                    if jnp.inf in pre_grid[i + 1]:
226                        #Create a list of the grid values along each coordinate in the edge i + 1
227                        grid_points = list()
228                        for j in range(len(pre_grid[i + 1])):
229                            #If the coordinate is free, create grid
230                            if pre_grid[i + 1][j] == jnp.inf:
231                                grid_points.append(jnp.linspace(xl[j],xu[j],Nb + 2)[1:-1].tolist())
232                            else:
233                                #If the coordinate is fixed, store its value
234                                grid_points.append([pre_grid[i + 1][j]])
235                        #Product of these values
236                        grid_values = [[x] for x in grid_points[0]]
237                        for j in range(len(grid_points) - 1):
238                            grid_values = [x1 + [x2] for x1 in grid_values for x2 in grid_points[j + 1]]
239                        #Append to data
240                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
241                    else:
242                        #If the point is a vertex, append it to data
243                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
244            else:
245                #Sample points over each edge of the n-cube
246                for i in range(len(pre_grid) - 1):
247                    if jnp.inf in pre_grid[i + 1]:
248                        #Product of the fixed and sampled values
249                        if jnp.inf == pre_grid[i + 1][0]:
250                            grid_values = [[x] for x in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nb,)).tolist()]
251                        else:
252                            grid_values = [[pre_grid[i + 1][0]]]
253                        for j in range(len(pre_grid[i + 1]) - 1):
254                            if jnp.inf == pre_grid[i + 1][j + 1]:
255                                grid_values = [x1 + [x2] for x1 in grid_values for x2 in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[j + 1],maxval = xu[j + 1],shape = (Nb,)).tolist()]
256                            else:
257                                grid_values = [x1 + [pre_grid[i + 1][j + 1]] for x1 in grid_values]
258                        #Append to data
259                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
260                    else:
261                        #If the point is a vertex, append it to data
262                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
263            if Ntb is not None:
264                #Product of x and t
265                xt_boundary = jnp.array([x.tolist() + [t.tolist()] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
266                #Calculate u at each point
267                u_boundary = jnp.array([[u(x,t) + sigmab*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
268                u_boundary = u_boundary.reshape((u_boundary.shape[0],p))
269            else:
270                xt_boundary = x_boundary
271                u_boundary = u(x_boundary)
272        else:
273            #Return None if boundary data should not be generated
274            xt_boundary = None
275            u_boundary = None
276
277        #Initial data
278        if N0 is not None:
279            if pos0 == 'grid':
280                #Create the grid for the first coordinate
281                x_initial = [jnp.linspace(xl[i],xu[i],N0) for i in range(d)]
282                x_initial = jnp.meshgrid(*x_initial, indexing='ij')
283                x_initial = jnp.stack(x_initial, axis=-1).reshape((-1, d))
284            else:
285                #Sample N0^d points for the first coordinate
286                x_initial = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (N0 ** d,1))
287                for i in range(d-1):
288                    #Sample N0^d points for the i-th coordinate and append collumn-wise
289                    x_initial =  jnp.append(x_initial,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (N0 ** d,1)),1)
290            x_initial = jnp.array(x_initial,dtype = jnp.float32)
291            #Product of x and t
292            xt_initial = jnp.array([x.tolist() + [t] for x in x_initial for t in [0.0]],dtype = jnp.float32)
293            #Calculate u at each point
294            u_initial = jnp.array([[u(x,t) + sigma0*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_initial for t in jnp.array([0.0])],dtype = jnp.float32)
295            u_initial = u_initial.reshape((u_initial.shape[0],p))
296        else:
297            #Return None if initial data should not be generated
298            xt_initial = None
299            u_initial = None
300    else:
301        if Ns is not None or Nts is not None:
302            if poss == 'grid':
303                #Create the grid
304                x_sensor = [jnp.linspace(xl[i],xu[i],Ns + 2)[1:-1] for i in range(d)]
305                x_sensor = jnp.meshgrid(*x_sensor, indexing='ij')
306                x_sensor = jnp.stack(x_sensor, axis=-1).reshape((-1, d))
307            else:
308                #Sample Ns^d points for the first coordinate
309                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
310                for i in range(d-1):
311                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
312                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
313            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
314            if Nts is not None:
315                if posts == 'grid':
316                    #Create the Nt grid of (tl,tu]
317                    t_sensor = jnp.linspace(tl,tu,Nts)
318                else:
319                    #Sample Nt points from (tl,tu)
320                    t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
321                #Product of x and t
322                xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
323                #Calculate u at each point
324                u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
325                u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
326            else:
327                xt_sensor = x_sensor
328                u_sensor = u(x_sensor)
329        else:
330            #Return None if sensor data should not be generated
331            xt_sensor = None
332            u_sensor = None
333
334    #Create data structure
335    if train:
336        dat = {'sensor': xt_sensor,'usensor': u_sensor,'boundary': xt_boundary,'uboundary': u_boundary,'initial': xt_initial,'uinitial': u_initial,'collocation': xt_collocation}
337    else:
338        dat = {'xt': xt_sensor,'u': u_sensor}
339
340    return dat
341
# Read and organize a data.frame
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a tabular data file and convert it to a JAX array.

    Parameters
    ----------
    file : str or os.PathLike
        Path to a file with extension .csv, .txt, .xls or .xlsx (case-insensitive)
    sep : str, optional
        Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt
    header : int, Sequence of int, 'infer' or None
        See pandas.read_csv documentation. Default 'infer'
    sheet : int or str
        Sheet number (or name) for .xls and .xlsx files. Default 0

    Returns
    -------
    jax.Array
        float32 array holding the table's values

    Raises
    ------
    ValueError
        If the file extension is not one of the supported ones.
    """
    import os

    # splitext only inspects the final path component, so dots in directory
    # names or earlier in the file name do not affect the detected extension
    ext = os.path.splitext(os.fspath(file))[1].lstrip('.').lower()

    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError(
            f"Unsupported file extension {ext!r}: expected .csv, .txt, .xls or .xlsx"
        )

    # Convert to JAX data structure
    return jnp.array(dat.to_numpy(), dtype=jnp.float32)
def generate_PINNdata( u, xl, xu, tl=None, tu=None, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None, train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid', posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
 15def generate_PINNdata(u,xl,xu,tl = None,tu = None,Ns = None,Nts = None,Nb = None,Ntb = None,N0 = None,Nc = None,Ntc = None,train = True,d = 1,p = 1,poss = 'grid',posts = 'grid',pos0 = 'grid',posb = 'grid',postb = 'grid',posc = 'grid',postc = 'grid',sigmas = 0,sigmab = 0,sigma0 = 0):
 16    """
 17    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
 18    ----------
 19
 20    Parameters
 21    ----------
 22    u : function
 23
 24        The function u(x,t) solution of the PDE
 25
 26    xl : float
 27
 28        Lower bound of each x coordinate
 29
 30    xu : float
 31
 32        Upper bound of each x coordinate
 33
 34    tl : float
 35
 36        Lower bound of the time interval
 37
 38    tu : float
 39
 40        Upper bound of the time interval
 41
 42    Ns : int, None
 43
 44        Number of points along each x coordinate for sensor data. None for not generating sensor data
 45
 46    Nts : int, None
 47
 48        Number of points along the time axis for sensor data. None for not generating sensor data
 49
 50    Nb : int, None
 51
 52        Number of points along each x coordinate for boundary data. None for not generating boundary data
 53
 54    Ntb : int, None
 55
 56        Number of points along the time axis for boundary data. None for not generating boundary data
 57
 58    N0 : int, None
 59
 60        Number of points along each x coordinate for initial data. None for not generating initial data
 61
 62    Nc : int, None
 63
 64        Number of points along each x coordinate for collocation points. None for not generating collocation points
 65
 66    Ntc : int, None
 67
 68        Number of points along the time axis for collocation points. None for not generating collocation points
 69
 70    train : logical
 71
 72        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
 73
 74    d : int
 75
 76        Domain dimension. Default 1
 77
 78    p : int
 79
 80        Output dimension. Default 1
 81
 82    poss : str
 83
 84        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 85
 86    posts : str
 87
 88        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 89
 90    posb : int
 91
 92        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 93
 94    postb : int
 95
 96        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
 97
 98    pos0 : int
 99
100        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
101
102    posc : str
103
104        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
105
106    postc : str
107
108        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
109
110    sigmas : str
111
112        Standard deviation of the Gaussian noise of sensor data. Default 0
113
114    sigmab : str
115
116        Standard deviation of the Gaussian noise of boundary data. Default 0
117
118    sigma0 : str
119
120        Standard deviation of the Gaussian noise of initial data. Default 0
121
122    Returns
123    -------
124
125    dict-like object with generated data
126
127    """
128
129    #Repeat x limits
130    if isinstance(xl,int) or isinstance(xl,float):
131        xl = [xl for i in range(d)]
132    if isinstance(xu,int) or isinstance(xu,float):
133        xu = [xu for i in range(d)]
134
135    #Sensor data
136    if train:
137        if Ns is not None or Nts is not None:
138            if poss == 'grid':
139                #Create the grid
140                x_sensor = [jnp.linspace(xl[i],xu[i],Ns + 2)[1:-1] for i in range(d)]
141                x_sensor = jnp.meshgrid(*x_sensor, indexing='ij')
142                x_sensor = jnp.stack(x_sensor, axis=-1).reshape((-1, d))
143            else:
144                #Sample Ns^d points for the first coordinate
145                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
146                for i in range(d-1):
147                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
148                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
149            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
150            if Nts is not None:
151                if posts == 'grid':
152                    #Create the Nt grid of (tl,tu]
153                    t_sensor = jnp.linspace(tl,tu,Nts + 1)[1:]
154                else:
155                    #Sample Nt points from (tl,tu)
156                    t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
157                #Product of x and t
158                xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
159                #Calculate u at each point
160                u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
161                u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
162            else:
163                xt_sensor = x_sensor
164                #Calculate u at each point
165                u_sensor = u(xt_sensor)
166        else:
167            #Return None if sensor data should not be generated
168            xt_sensor = None
169            u_sensor = None
170
171        #Set collocation points (always in an interior grid)
172        if Ntc is not None or Nc is not None:
173            if Ntc is not None:
174                if postc == 'grid':
175                    #Create the Ntc grid of (tl,tu]
176                    t_collocation = jnp.linspace(tl,tu,Ntc + 1)[1:]
177                else:
178                    #Sample Ntc points from (tl,tu)
179                    t_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntc,))
180            if posc == 'grid':
181                #Create the grid
182                x_collocation = [jnp.linspace(xl[i],xu[i],Nc + 2)[1:-1] for i in range(d)]
183                x_collocation = jnp.meshgrid(*x_collocation, indexing='ij')
184                x_collocation = jnp.stack(x_collocation, axis=-1).reshape((-1, d))
185                if Ntc is not None:
186                    #Product of x and t
187                    xt_collocation = jnp.array([x.tolist() + [t.tolist()] for t in t_collocation for x in x_collocation],dtype = jnp.float32)
188                else:
189                    xt_collocation = x_collocation
190            else:
191                #Sample Nc^d points for the first coordinate
192                x_collocation = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nc ** d,1))
193                for i in range(d-1):
194                    #Sample Nc^d points for the i-th coordinate and append collumn-wise
195                    x_collocation =  jnp.append(x_collocation,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Nc ** d,1)),1)
196                if Ntc is not None:
197                    #Product of x and t
198                    xt_collocation = jnp.array([x.tolist() + [t.tolist()] for t in t_collocation for x in x_collocation],dtype = jnp.float32)
199                else:
200                    xt_collocation = x_collocation
201        else:
202            #Return None if collocation data should not be generated
203            xt_collocation = None
204
205        #Boundary data
206        if Ntb is not None or Nb is not None:
207            if Ntb is not None:
208                if postb == 'grid':
209                    #Create the Ntb grid of (tl,tu]
210                    t_boundary = jnp.linspace(tl,tu,Ntb + 1)[1:]
211                else:
212                    #Sample Ntb points from (tl,tu)
213                    t_boundary = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Ntb,))
214
215            #An array in which each line represents an edge of the n-cube
216            pre_grid = [[xl[0]],[xu[0]],[jnp.inf]]
217            for i in range(d - 1):
218                pre_grid = [x1 + [x2] for x1 in pre_grid for x2 in [xl[i + 1],xu[i + 1],jnp.inf]]
219            #Exclude last row
220            pre_grid = pre_grid[:-1]
221            #Create array with vertex (xl,...,xl)
222            x_boundary = jnp.array(pre_grid[0],dtype = jnp.float32).reshape((1,d))
223            if posb == 'grid':
224                #Create a grid over each edge of the n-cube
225                for i in range(len(pre_grid) - 1):
226                    if jnp.inf in pre_grid[i + 1]:
227                        #Create a list of the grid values along each coordinate in the edge i + 1
228                        grid_points = list()
229                        for j in range(len(pre_grid[i + 1])):
230                            #If the coordinate is free, create grid
231                            if pre_grid[i + 1][j] == jnp.inf:
232                                grid_points.append(jnp.linspace(xl[j],xu[j],Nb + 2)[1:-1].tolist())
233                            else:
234                                #If the coordinate is fixed, store its value
235                                grid_points.append([pre_grid[i + 1][j]])
236                        #Product of these values
237                        grid_values = [[x] for x in grid_points[0]]
238                        for j in range(len(grid_points) - 1):
239                            grid_values = [x1 + [x2] for x1 in grid_values for x2 in grid_points[j + 1]]
240                        #Append to data
241                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
242                    else:
243                        #If the point is a vertex, append it to data
244                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
245            else:
246                #Sample points over each edge of the n-cube
247                for i in range(len(pre_grid) - 1):
248                    if jnp.inf in pre_grid[i + 1]:
249                        #Product of the fixed and sampled values
250                        if jnp.inf == pre_grid[i + 1][0]:
251                            grid_values = [[x] for x in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Nb,)).tolist()]
252                        else:
253                            grid_values = [[pre_grid[i + 1][0]]]
254                        for j in range(len(pre_grid[i + 1]) - 1):
255                            if jnp.inf == pre_grid[i + 1][j + 1]:
256                                grid_values = [x1 + [x2] for x1 in grid_values for x2 in jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[j + 1],maxval = xu[j + 1],shape = (Nb,)).tolist()]
257                            else:
258                                grid_values = [x1 + [pre_grid[i + 1][j + 1]] for x1 in grid_values]
259                        #Append to data
260                        x_boundary = jnp.append(x_boundary,jnp.array(grid_values,dtype = jnp.float32).reshape((len(grid_values),d)),0)
261                    else:
262                        #If the point is a vertex, append it to data
263                        x_boundary = jnp.append(x_boundary,jnp.array(pre_grid[i + 1],dtype = jnp.float32).reshape((1,d)),0)
264            if Ntb is not None:
265                #Product of x and t
266                xt_boundary = jnp.array([x.tolist() + [t.tolist()] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
267                #Calculate u at each point
268                u_boundary = jnp.array([[u(x,t) + sigmab*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_boundary for t in t_boundary],dtype = jnp.float32)
269                u_boundary = u_boundary.reshape((u_boundary.shape[0],p))
270            else:
271                xt_boundary = x_boundary
272                u_boundary = u(x_boundary)
273        else:
274            #Return None if boundary data should not be generated
275            xt_boundary = None
276            u_boundary = None
277
278        #Initial data
279        if N0 is not None:
280            if pos0 == 'grid':
281                #Create the grid for the first coordinate
282                x_initial = [jnp.linspace(xl[i],xu[i],N0) for i in range(d)]
283                x_initial = jnp.meshgrid(*x_initial, indexing='ij')
284                x_initial = jnp.stack(x_initial, axis=-1).reshape((-1, d))
285            else:
286                #Sample N0^d points for the first coordinate
287                x_initial = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (N0 ** d,1))
288                for i in range(d-1):
289                    #Sample N0^d points for the i-th coordinate and append collumn-wise
290                    x_initial =  jnp.append(x_initial,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (N0 ** d,1)),1)
291            x_initial = jnp.array(x_initial,dtype = jnp.float32)
292            #Product of x and t
293            xt_initial = jnp.array([x.tolist() + [t] for x in x_initial for t in [0.0]],dtype = jnp.float32)
294            #Calculate u at each point
295            u_initial = jnp.array([[u(x,t) + sigma0*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)))] for x in x_initial for t in jnp.array([0.0])],dtype = jnp.float32)
296            u_initial = u_initial.reshape((u_initial.shape[0],p))
297        else:
298            #Return None if initial data should not be generated
299            xt_initial = None
300            u_initial = None
301    else:
302        if Ns is not None or Nts is not None:
303            if poss == 'grid':
304                #Create the grid
305                x_sensor = [jnp.linspace(xl[i],xu[i],Ns + 2)[1:-1] for i in range(d)]
306                x_sensor = jnp.meshgrid(*x_sensor, indexing='ij')
307                x_sensor = jnp.stack(x_sensor, axis=-1).reshape((-1, d))
308            else:
309                #Sample Ns^d points for the first coordinate
310                x_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[0],maxval = xu[0],shape = (Ns ** d,1))
311                for i in range(d-1):
312                    #Sample Ns^d points for the i-th coordinate and append collumn-wise
313                    x_sensor =  jnp.append(x_sensor,jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = xl[i+1],maxval = xu[i+1],shape = (Ns ** d,1)),1)
314            x_sensor = jnp.array(x_sensor,dtype = jnp.float32)
315            if Nts is not None:
316                if posts == 'grid':
317                    #Create the Nt grid of (tl,tu]
318                    t_sensor = jnp.linspace(tl,tu,Nts)
319                else:
320                    #Sample Nt points from (tl,tu)
321                    t_sensor = jax.random.uniform(key = jax.random.PRNGKey(random.randint(0,sys.maxsize)),minval = tl,maxval = tu,shape = (Nts,))
322                #Product of x and t
323                xt_sensor = jnp.array([x.tolist() + [t.tolist()] for x in x_sensor for t in t_sensor],dtype = jnp.float32)
324                #Calculate u at each point
325                u_sensor = jnp.array([u(x,t) + sigmas*jax.random.normal(key = jax.random.PRNGKey(random.randint(0,sys.maxsize))) for x in x_sensor for t in t_sensor],dtype = jnp.float32)
326                u_sensor = u_sensor.reshape((u_sensor.shape[0],p))
327            else:
328                xt_sensor = x_sensor
329                u_sensor = u(x_sensor)
330        else:
331            #Return None if sensor data should not be generated
332            xt_sensor = None
333            u_sensor = None
334
335    #Create data structure
336    if train:
337        dat = {'sensor': xt_sensor,'usensor': u_sensor,'boundary': xt_boundary,'uboundary': u_boundary,'initial': xt_initial,'uinitial': u_initial,'collocation': xt_collocation}
338    else:
339        dat = {'xt': xt_sensor,'u': u_sensor}
340
341    return dat

Generate spatio-temporal data in a d-dimensional cube for PINN simulation

Parameters
  • u (function): The function u(x,t) solution of the PDE
  • xl (float): Lower bound of each x coordinate
  • xu (float): Upper bound of each x coordinate
  • tl (float): Lower bound of the time interval
  • tu (float): Upper bound of the time interval
  • Ns (int, None): Number of points along each x coordinate for sensor data. None for not generating sensor data
  • Nts (int, None): Number of points along the time axis for sensor data. None for not generating sensor data
  • Nb (int, None): Number of points along each x coordinate for boundary data. None for not generating boundary data
  • Ntb (int, None): Number of points along the time axis for boundary data. None for not generating boundary data
  • N0 (int, None): Number of points along each x coordinate for initial data. None for not generating initial data
  • Nc (int, None): Number of points along each x coordinate for collocation points. None for not generating collocation points
  • Ntc (int, None): Number of points along the time axis for collocation points. None for not generating collocation points
  • train (bool): Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
  • d (int): Domain dimension. Default 1
  • p (int): Output dimension. Default 1
  • poss (str): Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posts (str): Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posb (str): Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • postb (str): Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • pos0 (str): Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • posc (str): Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • postc (str): Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
  • sigmas (float): Standard deviation of the Gaussian noise of sensor data. Default 0
  • sigmab (float): Standard deviation of the Gaussian noise of boundary data. Default 0
  • sigma0 (float): Standard deviation of the Gaussian noise of initial data. Default 0
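Returns
  • dict with generated data. For training data (train=True) the keys are 'sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial' and 'collocation'; for test data (train=False) the keys are 'xt' and 'u'. Data that was not requested (e.g. Nb is None) is returned as None.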
def read_data_frame(file, sep=None, header='infer', sheet=0):
def read_data_frame(file,sep = None,header = 'infer',sheet = 0):
    """
    Read a data file and convert to JAX array.
    -------

    Parameters
    ----------
    file : str or os.PathLike

        File name with extension .csv, .txt, .xls or .xlsx (extension matched case-insensitively)

    sep : str

        Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt

    header : int, Sequence of int, 'infer' or None

        See pandas.read_csv documentation. Default 'infer'

    sheet : int

        Sheet number for .xls and .xlsx files. Default 0

    Returns
    -------

    a JAX numpy array of dtype float32

    Raises
    ------

    ValueError
        If the file extension is not .csv, .txt, .xls or .xlsx
    """
    import os

    #Take the extension after the LAST dot of the file name, so paths such as
    #'./data/run.v1.csv' resolve to 'csv' (splitting on the first dot does not)
    ext = os.path.splitext(os.fspath(file))[1].lstrip('.').lower()

    #Read data frame
    if ext == 'csv':
        if sep is None:
            sep = ','
        dat = pandas.read_csv(file,sep = sep,header = header)
    elif ext == 'txt':
        if sep is None:
            sep = ' '
        dat = pandas.read_table(file,sep = sep,header = header)
    elif ext in ('xls','xlsx'):
        dat = pandas.read_excel(file,header = header,sheet_name = sheet)
    else:
        #Fail explicitly rather than reaching the conversion with 'dat' unbound
        raise ValueError(f"Unsupported file extension '{ext}': expected .csv, .txt, .xls or .xlsx")

    #Convert to JAX data structure via NumPy, which pandas exports natively
    return jnp.array(dat.to_numpy(),dtype = jnp.float32)

Read a data file and convert to JAX array.

Parameters
  • file (str): File name with extension .csv, .txt, .xls or .xlsx
  • sep (str): Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt
  • header (int, Sequence of int, ‘infer’ or None): See pandas.read_csv documentation. Default 'infer'
  • sheet (int): Sheet number for .xls and .xlsx files. Default 0
Returns
  • a JAX numpy array