jinnax.data
# Functions to process data for training
import pandas
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
# NOTE: this import rebinds ``random`` to the standard-library module; the code
# below relies on ``random.randint`` from the stdlib and always reaches the JAX
# generator through ``jax.random``.
import random
import sys
from PIL import Image
from IPython.display import display
__docformat__ = "numpy"


def _uniform_sample(low, high, shape):
    """Draw uniform samples in [low, high) with a key seeded from stdlib ``random``."""
    key = jax.random.PRNGKey(random.randint(0, sys.maxsize))
    return jax.random.uniform(key=key, minval=low, maxval=high, shape=shape)


def _gaussian_noise():
    """Draw one standard normal value with a key seeded from stdlib ``random``."""
    return jax.random.normal(key=jax.random.PRNGKey(random.randint(0, sys.maxsize)))


def _spatial_points(xl, xu, n, d, pos, interior=True):
    """Return spatial points in the box: an n^d grid if pos == 'grid', else n^d uniform samples."""
    if pos == 'grid':
        if interior:
            # n interior nodes per axis: both end points are dropped
            axes = [jnp.linspace(xl[i], xu[i], n + 2)[1:-1] for i in range(d)]
        else:
            axes = [jnp.linspace(xl[i], xu[i], n) for i in range(d)]
        mesh = jnp.meshgrid(*axes, indexing='ij')
        return jnp.stack(mesh, axis=-1).reshape((-1, d))
    # One column of n^d samples per coordinate, drawn in coordinate order
    columns = [_uniform_sample(xl[i], xu[i], (n ** d, 1)) for i in range(d)]
    return jnp.concatenate(columns, axis=1)


def _time_points(tl, tu, n, pos, include_start=False):
    """Return n time points: a grid of (tl, tu] (or [tl, tu] if include_start) or uniform samples."""
    if pos == 'grid':
        if include_start:
            return jnp.linspace(tl, tu, n)
        return jnp.linspace(tl, tu, n + 1)[1:]
    return _uniform_sample(tl, tu, (n,))


def _space_time_product(x, t, time_major=False):
    """Return the Cartesian product of the rows of x and the entries of t as rows (x, t)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    t = jnp.asarray(t, dtype=jnp.float32)
    if time_major:
        # t is the outer loop: all x for the first t, then all x for the second t, ...
        xs = jnp.tile(x, (t.shape[0], 1))
        ts = jnp.repeat(t, x.shape[0])
    else:
        # x is the outer loop: all t for the first x, then all t for the second x, ...
        xs = jnp.repeat(x, t.shape[0], axis=0)
        ts = jnp.tile(t, x.shape[0])
    return jnp.concatenate([xs, ts[:, None]], axis=1)


def _evaluate_with_noise(u, x, t, sigma, p):
    """Evaluate u(x, t) + sigma * N(0, 1) on the product of x and t (x outer loop); shape (N, p)."""
    # A noise value is drawn for every point even when sigma == 0 so that the
    # sequence of stdlib ``random`` draws does not depend on sigma.
    values = jnp.array(
        [u(xi, ti) + sigma * _gaussian_noise() for xi in x for ti in t],
        dtype=jnp.float32,
    )
    return values.reshape((values.shape[0], p))


def _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start):
    """Build the sensor inputs and the values of u on them."""
    x = jnp.array(_spatial_points(xl, xu, Ns, d, poss), dtype=jnp.float32)
    if Nts is None:
        # Time-independent problem: u is called once on the whole array, no noise is added
        return x, u(x)
    t = _time_points(tl, tu, Nts, posts, include_start=include_start)
    return _space_time_product(x, t), _evaluate_with_noise(u, x, t, sigmas, p)


def _boundary_points(xl, xu, Nb, d, posb):
    """Return points on every face (of every dimension) of the boundary of the box."""
    # Each face is described by one entry per coordinate: a fixed bound or
    # jnp.inf, which marks a coordinate that is free to vary along the face.
    faces = [[xl[0]], [xu[0]], [jnp.inf]]
    for j in range(1, d):
        faces = [face + [value] for face in faces for value in (xl[j], xu[j], jnp.inf)]
    # The last entry has every coordinate free: it is the interior, not a face
    faces = faces[:-1]

    blocks = []
    for face in faces:
        if jnp.inf not in face:
            # Every coordinate is fixed: the face is a vertex
            rows = [face]
        elif posb == 'grid':
            axes = [
                jnp.linspace(xl[j], xu[j], Nb + 2)[1:-1].tolist() if face[j] == jnp.inf else [face[j]]
                for j in range(d)
            ]
            rows = [[value] for value in axes[0]]
            for axis in axes[1:]:
                rows = [row + [value] for row in rows for value in axis]
        else:
            if face[0] == jnp.inf:
                rows = [[value] for value in _uniform_sample(xl[0], xu[0], (Nb,)).tolist()]
            else:
                rows = [[face[0]]]
            for j in range(1, d):
                if face[j] == jnp.inf:
                    # A fresh sample of Nb values is drawn for every partial point
                    rows = [
                        row + [value]
                        for row in rows
                        for value in _uniform_sample(xl[j], xu[j], (Nb,)).tolist()
                    ]
                else:
                    rows = [row + [face[j]] for row in rows]
        blocks.append(jnp.array(rows, dtype=jnp.float32).reshape((len(rows), d)))
    return jnp.concatenate(blocks, axis=0)


#Generate d-dimensional data for PINN training
def generate_PINNdata(u, xl, xu, tl=None, tu=None, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None, train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid', posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
    """
    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
    ----------

    Parameters
    ----------
    u : function

        The solution of the PDE. Called as u(x,t) for one point x (array with d entries) and one time t when a time count is given, and as u(x) on the whole array of points otherwise

    xl : float or sequence of float

        Lower bound of each x coordinate. A scalar is repeated for the d coordinates

    xu : float or sequence of float

        Upper bound of each x coordinate. A scalar is repeated for the d coordinates

    tl : float

        Lower bound of the time interval

    tu : float

        Upper bound of the time interval

    Ns : int, None

        Number of points along each x coordinate for sensor data. Required whenever sensor data is generated

    Nts : int, None

        Number of points along the time axis for sensor data. None generates time-independent sensor data. Sensor data is not generated when both Ns and Nts are None

    Nb : int, None

        Number of points along each free x coordinate of a boundary face. Not used when d = 1, since the boundary is then the two end points

    Ntb : int, None

        Number of points along the time axis for boundary data. None generates time-independent boundary data. Boundary data is not generated when both Nb and Ntb are None

    N0 : int, None

        Number of points along each x coordinate for initial data. None for not generating initial data

    Nc : int, None

        Number of points along each x coordinate for collocation points. Required whenever collocation points are generated

    Ntc : int, None

        Number of points along the time axis for collocation points. None generates time-independent collocation points. Collocation points are not generated when both Nc and Ntc are None

    train : bool

        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True

    d : int

        Domain dimension. Default 1

    p : int

        Output dimension. Default 1

    poss : str

        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posts : str

        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posb : str

        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    postb : str

        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    pos0 : str

        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posc : str

        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    postc : str

        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    sigmas : float

        Standard deviation of the Gaussian noise of sensor data. Only applied to time-dependent data. Default 0

    sigmab : float

        Standard deviation of the Gaussian noise of boundary data. Only applied to time-dependent data. Default 0

    sigma0 : float

        Standard deviation of the Gaussian noise of initial data. Default 0

    Returns
    -------

    dict with keys 'sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial' and 'collocation' for train data, and keys 'xt' and 'u' for test data. Entries that were not requested are None

    Notes
    -----

    Time grids are taken on (tl,tu] for train data and on [tl,tu] for test data. Initial data is always placed at t = 0.0, regardless of tl. Random positions are seeded through the standard library ``random`` module, so ``random.seed`` makes them reproducible.

    """

    #Repeat x limits (np.ndim also recognises NumPy scalars and 0-d arrays)
    if np.ndim(xl) == 0:
        xl = [xl for _ in range(d)]
    if np.ndim(xu) == 0:
        xu = [xu for _ in range(d)]

    want_sensor = Ns is not None or Nts is not None

    #Test data: only sensor data, with the time grid including tl
    if not train:
        if want_sensor:
            xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start=True)
        else:
            xt_sensor, u_sensor = None, None
        return {'xt': xt_sensor, 'u': u_sensor}

    #Sensor data
    if want_sensor:
        xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start=False)
    else:
        xt_sensor, u_sensor = None, None

    #Collocation points (time is sampled before space; rows are ordered by time first)
    if Ntc is not None or Nc is not None:
        t_collocation = _time_points(tl, tu, Ntc, postc) if Ntc is not None else None
        x_collocation = _spatial_points(xl, xu, Nc, d, posc)
        if t_collocation is None:
            xt_collocation = x_collocation
        else:
            xt_collocation = _space_time_product(x_collocation, t_collocation, time_major=True)
    else:
        xt_collocation = None

    #Boundary data
    if Ntb is not None or Nb is not None:
        t_boundary = _time_points(tl, tu, Ntb, postb) if Ntb is not None else None
        x_boundary = _boundary_points(xl, xu, Nb, d, posb)
        if t_boundary is None:
            xt_boundary = x_boundary
            u_boundary = u(x_boundary)
        else:
            xt_boundary = _space_time_product(x_boundary, t_boundary)
            u_boundary = _evaluate_with_noise(u, x_boundary, t_boundary, sigmab, p)
    else:
        xt_boundary, u_boundary = None, None

    #Initial data (the grid includes the end points of each coordinate)
    if N0 is not None:
        x_initial = jnp.array(_spatial_points(xl, xu, N0, d, pos0, interior=False), dtype=jnp.float32)
        t_initial = jnp.array([0.0])
        xt_initial = _space_time_product(x_initial, t_initial)
        u_initial = _evaluate_with_noise(u, x_initial, t_initial, sigma0, p)
    else:
        xt_initial, u_initial = None, None

    #Create data structure
    return {'sensor': xt_sensor, 'usensor': u_sensor, 'boundary': xt_boundary, 'uboundary': u_boundary, 'initial': xt_initial, 'uinitial': u_initial, 'collocation': xt_collocation}


#Read and organize a data.frame
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a data file and convert to JAX array.
    -------

    Parameters
    ----------
    file : str

        File name with extension .csv, .txt, .xls or .xlsx. The extension is the text after the last dot and is matched case-insensitively

    sep : str

        Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt

    header : int, Sequence of int, 'infer' or None

        See pandas.read_csv documentation. Default 'infer'

    sheet : int

        Sheet number for .xls and .xlsx files. Default 0

    Returns
    -------

    a JAX numpy array of type float32

    Raises
    ------

    ValueError if the file extension is not one of the supported ones

    """

    #Find out data extension: text after the LAST dot, so dots elsewhere in the path are harmless
    ext = str(file).rsplit('.', 1)[-1].lower()

    #Read data frame
    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError("Unsupported file extension for '" + str(file) + "': expected .csv, .txt, .xls or .xlsx")

    #Convert to JAX data structure
    return jnp.array(dat, dtype=jnp.float32)
def
generate_PINNdata( u, xl, xu, tl=None, tu=None, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None, train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid', posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
def _uniform_sample(low, high, shape):
    """Draw uniform samples in [low, high) with a key seeded from stdlib ``random``."""
    key = jax.random.PRNGKey(random.randint(0, sys.maxsize))
    return jax.random.uniform(key=key, minval=low, maxval=high, shape=shape)


def _gaussian_noise():
    """Draw one standard normal value with a key seeded from stdlib ``random``."""
    return jax.random.normal(key=jax.random.PRNGKey(random.randint(0, sys.maxsize)))


def _spatial_points(xl, xu, n, d, pos, interior=True):
    """Return spatial points in the box: an n^d grid if pos == 'grid', else n^d uniform samples."""
    if pos == 'grid':
        if interior:
            # n interior nodes per axis: both end points are dropped
            axes = [jnp.linspace(xl[i], xu[i], n + 2)[1:-1] for i in range(d)]
        else:
            axes = [jnp.linspace(xl[i], xu[i], n) for i in range(d)]
        mesh = jnp.meshgrid(*axes, indexing='ij')
        return jnp.stack(mesh, axis=-1).reshape((-1, d))
    # One column of n^d samples per coordinate, drawn in coordinate order
    columns = [_uniform_sample(xl[i], xu[i], (n ** d, 1)) for i in range(d)]
    return jnp.concatenate(columns, axis=1)


def _time_points(tl, tu, n, pos, include_start=False):
    """Return n time points: a grid of (tl, tu] (or [tl, tu] if include_start) or uniform samples."""
    if pos == 'grid':
        if include_start:
            return jnp.linspace(tl, tu, n)
        return jnp.linspace(tl, tu, n + 1)[1:]
    return _uniform_sample(tl, tu, (n,))


def _space_time_product(x, t, time_major=False):
    """Return the Cartesian product of the rows of x and the entries of t as rows (x, t)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    t = jnp.asarray(t, dtype=jnp.float32)
    if time_major:
        # t is the outer loop: all x for the first t, then all x for the second t, ...
        xs = jnp.tile(x, (t.shape[0], 1))
        ts = jnp.repeat(t, x.shape[0])
    else:
        # x is the outer loop: all t for the first x, then all t for the second x, ...
        xs = jnp.repeat(x, t.shape[0], axis=0)
        ts = jnp.tile(t, x.shape[0])
    return jnp.concatenate([xs, ts[:, None]], axis=1)


def _evaluate_with_noise(u, x, t, sigma, p):
    """Evaluate u(x, t) + sigma * N(0, 1) on the product of x and t (x outer loop); shape (N, p)."""
    # A noise value is drawn for every point even when sigma == 0 so that the
    # sequence of stdlib ``random`` draws does not depend on sigma.
    values = jnp.array(
        [u(xi, ti) + sigma * _gaussian_noise() for xi in x for ti in t],
        dtype=jnp.float32,
    )
    return values.reshape((values.shape[0], p))


def _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start):
    """Build the sensor inputs and the values of u on them."""
    x = jnp.array(_spatial_points(xl, xu, Ns, d, poss), dtype=jnp.float32)
    if Nts is None:
        # Time-independent problem: u is called once on the whole array, no noise is added
        return x, u(x)
    t = _time_points(tl, tu, Nts, posts, include_start=include_start)
    return _space_time_product(x, t), _evaluate_with_noise(u, x, t, sigmas, p)


def _boundary_points(xl, xu, Nb, d, posb):
    """Return points on every face (of every dimension) of the boundary of the box."""
    # Each face is described by one entry per coordinate: a fixed bound or
    # jnp.inf, which marks a coordinate that is free to vary along the face.
    faces = [[xl[0]], [xu[0]], [jnp.inf]]
    for j in range(1, d):
        faces = [face + [value] for face in faces for value in (xl[j], xu[j], jnp.inf)]
    # The last entry has every coordinate free: it is the interior, not a face
    faces = faces[:-1]

    blocks = []
    for face in faces:
        if jnp.inf not in face:
            # Every coordinate is fixed: the face is a vertex
            rows = [face]
        elif posb == 'grid':
            axes = [
                jnp.linspace(xl[j], xu[j], Nb + 2)[1:-1].tolist() if face[j] == jnp.inf else [face[j]]
                for j in range(d)
            ]
            rows = [[value] for value in axes[0]]
            for axis in axes[1:]:
                rows = [row + [value] for row in rows for value in axis]
        else:
            if face[0] == jnp.inf:
                rows = [[value] for value in _uniform_sample(xl[0], xu[0], (Nb,)).tolist()]
            else:
                rows = [[face[0]]]
            for j in range(1, d):
                if face[j] == jnp.inf:
                    # A fresh sample of Nb values is drawn for every partial point
                    rows = [
                        row + [value]
                        for row in rows
                        for value in _uniform_sample(xl[j], xu[j], (Nb,)).tolist()
                    ]
                else:
                    rows = [row + [face[j]] for row in rows]
        blocks.append(jnp.array(rows, dtype=jnp.float32).reshape((len(rows), d)))
    return jnp.concatenate(blocks, axis=0)


def generate_PINNdata(u, xl, xu, tl=None, tu=None, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None, train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid', posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
    """
    Generate spatio-temporal data in a d-dimensional cube for PINN simulation
    ----------

    Parameters
    ----------
    u : function

        The solution of the PDE. Called as u(x,t) for one point x (array with d entries) and one time t when a time count is given, and as u(x) on the whole array of points otherwise

    xl : float or sequence of float

        Lower bound of each x coordinate. A scalar is repeated for the d coordinates

    xu : float or sequence of float

        Upper bound of each x coordinate. A scalar is repeated for the d coordinates

    tl : float

        Lower bound of the time interval

    tu : float

        Upper bound of the time interval

    Ns : int, None

        Number of points along each x coordinate for sensor data. Required whenever sensor data is generated

    Nts : int, None

        Number of points along the time axis for sensor data. None generates time-independent sensor data. Sensor data is not generated when both Ns and Nts are None

    Nb : int, None

        Number of points along each free x coordinate of a boundary face. Not used when d = 1, since the boundary is then the two end points

    Ntb : int, None

        Number of points along the time axis for boundary data. None generates time-independent boundary data. Boundary data is not generated when both Nb and Ntb are None

    N0 : int, None

        Number of points along each x coordinate for initial data. None for not generating initial data

    Nc : int, None

        Number of points along each x coordinate for collocation points. Required whenever collocation points are generated

    Ntc : int, None

        Number of points along the time axis for collocation points. None generates time-independent collocation points. Collocation points are not generated when both Nc and Ntc are None

    train : bool

        Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True

    d : int

        Domain dimension. Default 1

    p : int

        Output dimension. Default 1

    poss : str

        Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posts : str

        Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posb : str

        Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    postb : str

        Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    pos0 : str

        Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    posc : str

        Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    postc : str

        Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'

    sigmas : float

        Standard deviation of the Gaussian noise of sensor data. Only applied to time-dependent data. Default 0

    sigmab : float

        Standard deviation of the Gaussian noise of boundary data. Only applied to time-dependent data. Default 0

    sigma0 : float

        Standard deviation of the Gaussian noise of initial data. Default 0

    Returns
    -------

    dict with keys 'sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial' and 'collocation' for train data, and keys 'xt' and 'u' for test data. Entries that were not requested are None

    Notes
    -----

    Time grids are taken on (tl,tu] for train data and on [tl,tu] for test data. Initial data is always placed at t = 0.0, regardless of tl. Random positions are seeded through the standard library ``random`` module, so ``random.seed`` makes them reproducible.

    """

    #Repeat x limits (np.ndim also recognises NumPy scalars and 0-d arrays)
    if np.ndim(xl) == 0:
        xl = [xl for _ in range(d)]
    if np.ndim(xu) == 0:
        xu = [xu for _ in range(d)]

    want_sensor = Ns is not None or Nts is not None

    #Test data: only sensor data, with the time grid including tl
    if not train:
        if want_sensor:
            xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start=True)
        else:
            xt_sensor, u_sensor = None, None
        return {'xt': xt_sensor, 'u': u_sensor}

    #Sensor data
    if want_sensor:
        xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, include_start=False)
    else:
        xt_sensor, u_sensor = None, None

    #Collocation points (time is sampled before space; rows are ordered by time first)
    if Ntc is not None or Nc is not None:
        t_collocation = _time_points(tl, tu, Ntc, postc) if Ntc is not None else None
        x_collocation = _spatial_points(xl, xu, Nc, d, posc)
        if t_collocation is None:
            xt_collocation = x_collocation
        else:
            xt_collocation = _space_time_product(x_collocation, t_collocation, time_major=True)
    else:
        xt_collocation = None

    #Boundary data
    if Ntb is not None or Nb is not None:
        t_boundary = _time_points(tl, tu, Ntb, postb) if Ntb is not None else None
        x_boundary = _boundary_points(xl, xu, Nb, d, posb)
        if t_boundary is None:
            xt_boundary = x_boundary
            u_boundary = u(x_boundary)
        else:
            xt_boundary = _space_time_product(x_boundary, t_boundary)
            u_boundary = _evaluate_with_noise(u, x_boundary, t_boundary, sigmab, p)
    else:
        xt_boundary, u_boundary = None, None

    #Initial data (the grid includes the end points of each coordinate)
    if N0 is not None:
        x_initial = jnp.array(_spatial_points(xl, xu, N0, d, pos0, interior=False), dtype=jnp.float32)
        t_initial = jnp.array([0.0])
        xt_initial = _space_time_product(x_initial, t_initial)
        u_initial = _evaluate_with_noise(u, x_initial, t_initial, sigma0, p)
    else:
        xt_initial, u_initial = None, None

    #Create data structure
    return {'sensor': xt_sensor, 'usensor': u_sensor, 'boundary': xt_boundary, 'uboundary': u_boundary, 'initial': xt_initial, 'uinitial': u_initial, 'collocation': xt_collocation}
Generate spatio-temporal data in a d-dimensional cube for PINN simulation
Parameters
- u (function): The function u(x,t) solution of the PDE
- xl (float): Lower bound of each x coordinate
- xu (float): Upper bound of each x coordinate
- tl (float): Lower bound of the time interval
- tu (float): Upper bound of the time interval
- Ns (int, None): Number of points along each x coordinate for sensor data. None for not generating sensor data
- Nts (int, None): Number of points along the time axis for sensor data. None for not generating sensor data
- Nb (int, None): Number of points along each x coordinate for boundary data. None for not generating boundary data
- Ntb (int, None): Number of points along the time axis for boundary data. None for not generating boundary data
- N0 (int, None): Number of points along each x coordinate for initial data. None for not generating initial data
- Nc (int, None): Number of points along each x coordinate for collocation points. None for not generating collocation points
- Ntc (int, None): Number of points along the time axis for collocation points. None for not generating collocation points
- train (logical): Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
- d (int): Domain dimension. Default 1
- p (int): Output dimension. Default 1
- poss (str): Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posts (str): Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posb (str): Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- postb (str): Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- pos0 (str): Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posc (str): Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- postc (str): Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- sigmas (float): Standard deviation of the Gaussian noise of sensor data. Default 0
- sigmab (float): Standard deviation of the Gaussian noise of boundary data. Default 0
- sigma0 (float): Standard deviation of the Gaussian noise of initial data. Default 0
Returns
- dict-like object with generated data
def
read_data_frame(file, sep=None, header='infer', sheet=0):
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a data file and convert to JAX array.
    -------

    Parameters
    ----------
    file : str

        File name with extension .csv, .txt, .xls or .xlsx. The extension is the text after the last dot and is matched case-insensitively

    sep : str

        Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt

    header : int, Sequence of int, 'infer' or None

        See pandas.read_csv documentation. Default 'infer'

    sheet : int

        Sheet number for .xls and .xlsx files. Default 0

    Returns
    -------

    a JAX numpy array of type float32

    Raises
    ------

    ValueError if the file extension is not one of the supported ones

    """

    #Find out data extension: text after the LAST dot, so dots elsewhere in the path are harmless
    ext = str(file).rsplit('.', 1)[-1].lower()

    #Read data frame
    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError("Unsupported file extension for '" + str(file) + "': expected .csv, .txt, .xls or .xlsx")

    #Convert to JAX data structure
    return jnp.array(dat, dtype=jnp.float32)
Read a data file and convert to JAX array.
Parameters
- file (str): File name with extension .csv, .txt, .xls or .xlsx
- sep (str): Separation character for .csv and .txt files. Default ',' for .csv and ' ' for .txt
- header (int, Sequence of int, ‘infer’ or None): See pandas.read_csv documentation. Default 'infer'
- sheet (int): Sheet number for .xls and .xlsx files. Default 0
Returns
- a JAX numpy array