jinnax.data
#Functions to process data for training
import pandas
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
# NOTE: this shadows ``jax.random`` imported above; every ``random.*`` call in
# this module therefore refers to Python's standard-library ``random``.
import random
import sys

# Pillow and IPython are not used by the functions of this module. They are
# imported when available (so the names stay importable from here) without
# making them hard requirements for the data utilities.
try:
    from PIL import Image
except ImportError:  # pragma: no cover - optional dependency
    Image = None
try:
    from IPython.display import display
except ImportError:  # pragma: no cover - optional dependency
    display = None

__docformat__ = "numpy"


def _new_key():
    """Return a fresh JAX PRNG key seeded from Python's global ``random`` state."""
    return jax.random.PRNGKey(random.randint(0, sys.maxsize))


def _axis_grid(lower, upper, n, interior):
    """Return ``n`` equally spaced values (Python floats) between ``lower`` and ``upper``.

    With ``interior=True`` the end points are excluded: the values are the
    interior nodes of an ``n + 2`` point grid of ``[lower, upper]``.
    """
    if interior:
        return jnp.linspace(lower, upper, n + 2)[1:-1].tolist()
    return jnp.linspace(lower, upper, n).tolist()


def _space_points(xl, xu, n, d, pos, interior):
    """Return ``n ** d`` points of the box ``prod_i [xl[i], xu[i]]`` as a float32 ``(n ** d, d)`` array.

    ``pos == 'grid'`` builds the tensor-product grid (first coordinate varies
    slowest); any other value samples each coordinate uniformly at random,
    drawing one fresh key per coordinate (coordinate 0 first).
    """
    if pos == 'grid':
        rows = [[]]
        for i in range(d):
            axis = _axis_grid(xl[i], xu[i], n, interior)
            rows = [row + [value] for row in rows for value in axis]
        return jnp.array(rows, dtype=jnp.float32)
    columns = [
        jax.random.uniform(key=_new_key(), minval=xl[i], maxval=xu[i], shape=(n ** d, 1))
        for i in range(d)
    ]
    return jnp.concatenate(columns, axis=1).astype(jnp.float32)


def _time_points(tl, tu, n, pos, closed):
    """Return ``n`` time values as a 1-D array.

    With ``pos == 'grid'`` they are equally spaced over ``[tl, tu]`` when
    ``closed`` is True and over ``(tl, tu]`` otherwise; any other ``pos``
    samples them uniformly at random between ``tl`` and ``tu``.
    """
    if pos == 'grid':
        if closed:
            return jnp.linspace(tl, tu, n)
        return jnp.linspace(tl, tu, n + 1)[1:]
    return jax.random.uniform(key=_new_key(), minval=tl, maxval=tu, shape=(n,))


def _boundary_points(xl, xu, n, d, pos):
    """Return points on the boundary of the box ``prod_i [xl[i], xu[i]]`` as a float32 ``(m, d)`` array.

    Every boundary piece is described by fixing each coordinate to its lower
    bound, its upper bound, or leaving it free (``None``). Pieces with no free
    coordinate are vertices and contribute themselves; the others contribute
    ``n`` interior grid nodes (``pos == 'grid'``) or ``n`` uniform samples per
    free coordinate.
    """
    # Enumerate (lower, upper, free) for every coordinate, first coordinate slowest
    faces = [[]]
    for i in range(d):
        faces = [face + [choice] for face in faces for choice in (xl[i], xu[i], None)]
    # The last combination leaves every coordinate free, i.e. it is the interior
    faces = faces[:-1]

    rows = []
    for face in faces:
        if all(value is not None for value in face):
            # A vertex of the cube
            rows.append(list(face))
            continue
        points = [[]]
        for j, value in enumerate(face):
            if value is not None:
                points = [point + [value] for point in points]
            elif pos == 'grid':
                axis = _axis_grid(xl[j], xu[j], n, True)
                points = [point + [v] for point in points for v in axis]
            else:
                # A fresh sample of n values is drawn for every partial point
                extended = []
                for point in points:
                    samples = jax.random.uniform(key=_new_key(), minval=xl[j], maxval=xu[j], shape=(n,)).tolist()
                    extended.extend(point + [v] for v in samples)
                points = extended
        rows.extend(points)
    return jnp.array(rows, dtype=jnp.float32).reshape((len(rows), d))


def _space_time(x, t):
    """Return every (x, t) pair as a float32 ``(len(x) * len(t), d + 1)`` array.

    Rows are ordered x-major: all times for ``x[0]``, then all times for ``x[1]``, ...
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    t = jnp.asarray(t, dtype=jnp.float32)
    return jnp.concatenate(
        [jnp.repeat(x, t.shape[0], axis=0), jnp.tile(t, x.shape[0])[:, None]],
        axis=1,
    )


def _evaluate(u, x, t, sigma, p):
    """Evaluate ``u`` at every (x, t) pair (x-major order) and add Gaussian noise.

    One fresh key is drawn per point, also when ``sigma`` is 0. Returns a
    float32 array of shape ``(len(x) * len(t), p)``.
    """
    values = jnp.array(
        [u(xi, ti) + sigma * jax.random.normal(key=_new_key()) for xi in x for ti in t],
        dtype=jnp.float32,
    )
    return values.reshape((values.shape[0], p))


def _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, train):
    """Return ``(xt, u)`` sensor arrays, or ``(None, None)`` when ``Ns`` or ``Nts`` is None.

    Training sensors use interior spatial grids and a time grid of
    ``(tl, tu]``; test sensors use closed grids in every coordinate and ``[tl, tu]``.
    """
    if Ns is None or Nts is None:
        return None, None
    x = _space_points(xl, xu, Ns, d, poss, interior=train)
    t = _time_points(tl, tu, Nts, posts, closed=not train)
    return _space_time(x, t), _evaluate(u, x, t, sigmas, p)


#Generate d-dimensional data for PINN training
def generate_PINNdata(u, xl, xu, tl, tu, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None,
                      train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid',
                      posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
    """
    Generate spatio-temporal data in a d-dimensional cube for PINN simulation.

    The spatial domain is the box ``[xl[0], xu[0]] x ... x [xl[d-1], xu[d-1]]``
    and the time domain is the interval ``[tl, tu]``.

    Parameters
    ----------
    u : callable
        The function ``u(x, t)``, solution of the PDE. It is called with ``x``
        a row of ``d`` spatial coordinates and ``t`` a scalar (both JAX
        arrays); its output must hold ``p`` values.
    xl : float or sequence of float
        Lower bound of each x coordinate. A scalar is used for every coordinate.
    xu : float or sequence of float
        Upper bound of each x coordinate. A scalar is used for every coordinate.
    tl : float
        Lower bound of the time interval. Initial data are placed at ``t = tl``.
    tu : float
        Upper bound of the time interval.
    Ns : int or None
        Number of points along each x coordinate for sensor data. None to skip sensor data.
    Nts : int or None
        Number of points along the time axis for sensor data. None to skip sensor data.
    Nb : int or None
        Number of points along each free x coordinate of every boundary face
        (the vertices of the cube are always included). None to skip boundary data.
    Ntb : int or None
        Number of points along the time axis for boundary data. None to skip boundary data.
    N0 : int or None
        Number of points along each x coordinate for initial data. None to skip initial data.
    Nc : int or None
        Number of points along each x coordinate for collocation points. None to skip collocation points.
    Ntc : int or None
        Number of points along the time axis for collocation points. None to skip collocation points.
    train : bool
        Whether to generate train (True) or test (False) data. Only sensor
        data is generated for test data. Default True.
    d : int
        Domain dimension. Default 1.
    p : int
        Output dimension. Default 1.
    poss, posts : str
        Position of sensor data in the spatial domain / time interval: 'grid'
        for a regular grid, any other value (e.g. 'random') for uniform sampling. Default 'grid'.
    posb, postb : str
        Same as above, for boundary data. Default 'grid'.
    pos0 : str
        Same as above, for initial data in the spatial domain. Default 'grid'.
    posc, postc : str
        Same as above, for collocation points. Default 'grid'.
    sigmas, sigmab, sigma0 : float
        Standard deviation of the Gaussian noise added to sensor, boundary and
        initial values, respectively. Default 0.

    Returns
    -------
    dict
        For ``train=True`` the keys are 'sensor', 'usensor', 'boundary',
        'uboundary', 'initial', 'uinitial' and 'collocation'; for
        ``train=False`` they are 'xt' and 'u'. Point arrays have shape
        ``(n, d + 1)`` with time in the last column; value arrays have shape
        ``(n, p)``. Data that was not requested is None.

    Notes
    -----
    Training grids exclude the end points of every spatial coordinate and
    cover ``(tl, tu]`` in time; initial data use a closed spatial grid. Test
    grids are closed in every coordinate and in time. Random keys are drawn
    from Python's global ``random`` module, so ``random.seed`` makes random
    placements and noise reproducible.
    """
    #Repeat scalar x limits for every coordinate
    if np.ndim(xl) == 0:
        xl = [xl] * d
    if np.ndim(xu) == 0:
        xu = [xu] * d

    #Sensor data (the only data generated for test sets)
    xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, train)
    if not train:
        return {'xt': xt_sensor, 'u': u_sensor}

    #Collocation points: interior of the cube, times in (tl, tu]
    xt_collocation = None
    if Ntc is not None and Nc is not None:
        t_collocation = _time_points(tl, tu, Ntc, postc, closed=False)
        x_collocation = _space_points(xl, xu, Nc, d, posc, interior=True)
        xt_collocation = _space_time(x_collocation, t_collocation)

    #Boundary data: vertices and faces of the cube, times in (tl, tu]
    xt_boundary = u_boundary = None
    if Ntb is not None and Nb is not None:
        t_boundary = _time_points(tl, tu, Ntb, postb, closed=False)
        x_boundary = _boundary_points(xl, xu, Nb, d, posb)
        xt_boundary = _space_time(x_boundary, t_boundary)
        u_boundary = _evaluate(u, x_boundary, t_boundary, sigmab, p)

    #Initial data: closed spatial grid/sample at the start of the time interval
    xt_initial = u_initial = None
    if N0 is not None:
        x_initial = _space_points(xl, xu, N0, d, pos0, interior=False)
        t_initial = jnp.array([tl], dtype=jnp.float32)
        xt_initial = _space_time(x_initial, t_initial)
        u_initial = _evaluate(u, x_initial, t_initial, sigma0, p)

    return {'sensor': xt_sensor, 'usensor': u_sensor, 'boundary': xt_boundary, 'uboundary': u_boundary,
            'initial': xt_initial, 'uinitial': u_initial, 'collocation': xt_collocation}


#Read and organize a data.frame
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a data file and convert it to a JAX array.

    Parameters
    ----------
    file : str or path-like
        File name with extension .csv, .txt, .xls or .xlsx (case-insensitive).
    sep : str, optional
        Field separator for .csv and .txt files. Default ',' for .csv and ' ' for .txt.
    header : int, sequence of int, 'infer' or None
        Row(s) to use as column names; see the pandas.read_csv documentation. Default 'infer'.
    sheet : int
        Sheet passed to ``pandas.read_excel`` as ``sheet_name`` for .xls and .xlsx files. Default 0.

    Returns
    -------
    jax.numpy.ndarray
        float32 array with the values of the table.

    Raises
    ------
    ValueError
        If the file extension is not one of the supported ones.
    """
    #Use the text after the LAST dot so paths such as './data/file.csv' or
    #'run.1.csv' are recognised
    name = str(file)
    ext = name.rsplit('.', 1)[-1].lower() if '.' in name else ''

    #Read data frame
    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError(f"Unsupported file extension {ext!r} in {name!r}; expected .csv, .txt, .xls or .xlsx")

    #Convert to JAX data structure
    return jnp.array(dat.to_numpy(), dtype=jnp.float32)
def _new_key():
    """Return a fresh JAX PRNG key seeded from Python's global ``random`` state."""
    return jax.random.PRNGKey(random.randint(0, sys.maxsize))


def _axis_grid(lower, upper, n, interior):
    """Return ``n`` equally spaced values (Python floats) between ``lower`` and ``upper``.

    With ``interior=True`` the end points are excluded: the values are the
    interior nodes of an ``n + 2`` point grid of ``[lower, upper]``.
    """
    if interior:
        return jnp.linspace(lower, upper, n + 2)[1:-1].tolist()
    return jnp.linspace(lower, upper, n).tolist()


def _space_points(xl, xu, n, d, pos, interior):
    """Return ``n ** d`` points of the box ``prod_i [xl[i], xu[i]]`` as a float32 ``(n ** d, d)`` array.

    ``pos == 'grid'`` builds the tensor-product grid (first coordinate varies
    slowest); any other value samples each coordinate uniformly at random,
    drawing one fresh key per coordinate (coordinate 0 first).
    """
    if pos == 'grid':
        rows = [[]]
        for i in range(d):
            axis = _axis_grid(xl[i], xu[i], n, interior)
            rows = [row + [value] for row in rows for value in axis]
        return jnp.array(rows, dtype=jnp.float32)
    columns = [
        jax.random.uniform(key=_new_key(), minval=xl[i], maxval=xu[i], shape=(n ** d, 1))
        for i in range(d)
    ]
    return jnp.concatenate(columns, axis=1).astype(jnp.float32)


def _time_points(tl, tu, n, pos, closed):
    """Return ``n`` time values as a 1-D array.

    With ``pos == 'grid'`` they are equally spaced over ``[tl, tu]`` when
    ``closed`` is True and over ``(tl, tu]`` otherwise; any other ``pos``
    samples them uniformly at random between ``tl`` and ``tu``.
    """
    if pos == 'grid':
        if closed:
            return jnp.linspace(tl, tu, n)
        return jnp.linspace(tl, tu, n + 1)[1:]
    return jax.random.uniform(key=_new_key(), minval=tl, maxval=tu, shape=(n,))


def _boundary_points(xl, xu, n, d, pos):
    """Return points on the boundary of the box ``prod_i [xl[i], xu[i]]`` as a float32 ``(m, d)`` array.

    Every boundary piece is described by fixing each coordinate to its lower
    bound, its upper bound, or leaving it free (``None``). Pieces with no free
    coordinate are vertices and contribute themselves; the others contribute
    ``n`` interior grid nodes (``pos == 'grid'``) or ``n`` uniform samples per
    free coordinate.
    """
    # Enumerate (lower, upper, free) for every coordinate, first coordinate slowest
    faces = [[]]
    for i in range(d):
        faces = [face + [choice] for face in faces for choice in (xl[i], xu[i], None)]
    # The last combination leaves every coordinate free, i.e. it is the interior
    faces = faces[:-1]

    rows = []
    for face in faces:
        if all(value is not None for value in face):
            # A vertex of the cube
            rows.append(list(face))
            continue
        points = [[]]
        for j, value in enumerate(face):
            if value is not None:
                points = [point + [value] for point in points]
            elif pos == 'grid':
                axis = _axis_grid(xl[j], xu[j], n, True)
                points = [point + [v] for point in points for v in axis]
            else:
                # A fresh sample of n values is drawn for every partial point
                extended = []
                for point in points:
                    samples = jax.random.uniform(key=_new_key(), minval=xl[j], maxval=xu[j], shape=(n,)).tolist()
                    extended.extend(point + [v] for v in samples)
                points = extended
        rows.extend(points)
    return jnp.array(rows, dtype=jnp.float32).reshape((len(rows), d))


def _space_time(x, t):
    """Return every (x, t) pair as a float32 ``(len(x) * len(t), d + 1)`` array.

    Rows are ordered x-major: all times for ``x[0]``, then all times for ``x[1]``, ...
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    t = jnp.asarray(t, dtype=jnp.float32)
    return jnp.concatenate(
        [jnp.repeat(x, t.shape[0], axis=0), jnp.tile(t, x.shape[0])[:, None]],
        axis=1,
    )


def _evaluate(u, x, t, sigma, p):
    """Evaluate ``u`` at every (x, t) pair (x-major order) and add Gaussian noise.

    One fresh key is drawn per point, also when ``sigma`` is 0. Returns a
    float32 array of shape ``(len(x) * len(t), p)``.
    """
    values = jnp.array(
        [u(xi, ti) + sigma * jax.random.normal(key=_new_key()) for xi in x for ti in t],
        dtype=jnp.float32,
    )
    return values.reshape((values.shape[0], p))


def _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, train):
    """Return ``(xt, u)`` sensor arrays, or ``(None, None)`` when ``Ns`` or ``Nts`` is None.

    Training sensors use interior spatial grids and a time grid of
    ``(tl, tu]``; test sensors use closed grids in every coordinate and ``[tl, tu]``.
    """
    if Ns is None or Nts is None:
        return None, None
    x = _space_points(xl, xu, Ns, d, poss, interior=train)
    t = _time_points(tl, tu, Nts, posts, closed=not train)
    return _space_time(x, t), _evaluate(u, x, t, sigmas, p)


def generate_PINNdata(u, xl, xu, tl, tu, Ns=None, Nts=None, Nb=None, Ntb=None, N0=None, Nc=None, Ntc=None,
                      train=True, d=1, p=1, poss='grid', posts='grid', pos0='grid', posb='grid', postb='grid',
                      posc='grid', postc='grid', sigmas=0, sigmab=0, sigma0=0):
    """
    Generate spatio-temporal data in a d-dimensional cube for PINN simulation.

    The spatial domain is the box ``[xl[0], xu[0]] x ... x [xl[d-1], xu[d-1]]``
    and the time domain is the interval ``[tl, tu]``.

    Parameters
    ----------
    u : callable
        The function ``u(x, t)``, solution of the PDE. It is called with ``x``
        a row of ``d`` spatial coordinates and ``t`` a scalar (both JAX
        arrays); its output must hold ``p`` values.
    xl : float or sequence of float
        Lower bound of each x coordinate. A scalar is used for every coordinate.
    xu : float or sequence of float
        Upper bound of each x coordinate. A scalar is used for every coordinate.
    tl : float
        Lower bound of the time interval. Initial data are placed at ``t = tl``.
    tu : float
        Upper bound of the time interval.
    Ns : int or None
        Number of points along each x coordinate for sensor data. None to skip sensor data.
    Nts : int or None
        Number of points along the time axis for sensor data. None to skip sensor data.
    Nb : int or None
        Number of points along each free x coordinate of every boundary face
        (the vertices of the cube are always included). None to skip boundary data.
    Ntb : int or None
        Number of points along the time axis for boundary data. None to skip boundary data.
    N0 : int or None
        Number of points along each x coordinate for initial data. None to skip initial data.
    Nc : int or None
        Number of points along each x coordinate for collocation points. None to skip collocation points.
    Ntc : int or None
        Number of points along the time axis for collocation points. None to skip collocation points.
    train : bool
        Whether to generate train (True) or test (False) data. Only sensor
        data is generated for test data. Default True.
    d : int
        Domain dimension. Default 1.
    p : int
        Output dimension. Default 1.
    poss, posts : str
        Position of sensor data in the spatial domain / time interval: 'grid'
        for a regular grid, any other value (e.g. 'random') for uniform sampling. Default 'grid'.
    posb, postb : str
        Same as above, for boundary data. Default 'grid'.
    pos0 : str
        Same as above, for initial data in the spatial domain. Default 'grid'.
    posc, postc : str
        Same as above, for collocation points. Default 'grid'.
    sigmas, sigmab, sigma0 : float
        Standard deviation of the Gaussian noise added to sensor, boundary and
        initial values, respectively. Default 0.

    Returns
    -------
    dict
        For ``train=True`` the keys are 'sensor', 'usensor', 'boundary',
        'uboundary', 'initial', 'uinitial' and 'collocation'; for
        ``train=False`` they are 'xt' and 'u'. Point arrays have shape
        ``(n, d + 1)`` with time in the last column; value arrays have shape
        ``(n, p)``. Data that was not requested is None.

    Notes
    -----
    Training grids exclude the end points of every spatial coordinate and
    cover ``(tl, tu]`` in time; initial data use a closed spatial grid. Test
    grids are closed in every coordinate and in time. Random keys are drawn
    from Python's global ``random`` module, so ``random.seed`` makes random
    placements and noise reproducible.
    """
    #Repeat scalar x limits for every coordinate
    if np.ndim(xl) == 0:
        xl = [xl] * d
    if np.ndim(xu) == 0:
        xu = [xu] * d

    #Sensor data (the only data generated for test sets)
    xt_sensor, u_sensor = _sensor_data(u, xl, xu, tl, tu, Ns, Nts, d, p, poss, posts, sigmas, train)
    if not train:
        return {'xt': xt_sensor, 'u': u_sensor}

    #Collocation points: interior of the cube, times in (tl, tu]
    xt_collocation = None
    if Ntc is not None and Nc is not None:
        t_collocation = _time_points(tl, tu, Ntc, postc, closed=False)
        x_collocation = _space_points(xl, xu, Nc, d, posc, interior=True)
        xt_collocation = _space_time(x_collocation, t_collocation)

    #Boundary data: vertices and faces of the cube, times in (tl, tu]
    xt_boundary = u_boundary = None
    if Ntb is not None and Nb is not None:
        t_boundary = _time_points(tl, tu, Ntb, postb, closed=False)
        x_boundary = _boundary_points(xl, xu, Nb, d, posb)
        xt_boundary = _space_time(x_boundary, t_boundary)
        u_boundary = _evaluate(u, x_boundary, t_boundary, sigmab, p)

    #Initial data: closed spatial grid/sample at the start of the time interval
    xt_initial = u_initial = None
    if N0 is not None:
        x_initial = _space_points(xl, xu, N0, d, pos0, interior=False)
        t_initial = jnp.array([tl], dtype=jnp.float32)
        xt_initial = _space_time(x_initial, t_initial)
        u_initial = _evaluate(u, x_initial, t_initial, sigma0, p)

    return {'sensor': xt_sensor, 'usensor': u_sensor, 'boundary': xt_boundary, 'uboundary': u_boundary,
            'initial': xt_initial, 'uinitial': u_initial, 'collocation': xt_collocation}
Generate spatio-temporal data in a d-dimensional cube for PINN simulation
Parameters
- u (function): The function u(x,t) solution of the PDE
- xl (float): Lower bound of each x coordinate
- xu (float): Upper bound of each x coordinate
- tl (float): Lower bound of the time interval
- tu (float): Upper bound of the time interval
- Ns (int, None): Number of points along each x coordinate for sensor data. None to skip generating sensor data
- Nts (int, None): Number of points along the time axis for sensor data. None to skip generating sensor data
- Nb (int, None): Number of points along each free x coordinate of every boundary face (the vertices of the cube are always included). None to skip generating boundary data
- Ntb (int, None): Number of points along the time axis for boundary data. None to skip generating boundary data
- N0 (int, None): Number of points along each x coordinate for initial data. None to skip generating initial data
- Nc (int, None): Number of points along each x coordinate for collocation points. None to skip generating collocation points
- Ntc (int, None): Number of points along the time axis for collocation points. None to skip generating collocation points
- train (bool): Whether to generate train (True) or test (False) data. Only sensor data is generated for test data. Default True
- d (int): Domain dimension. Default 1
- p (int): Output dimension. Default 1
- poss (str): Position of sensor data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posts (str): Position of sensor data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posb (str): Position of boundary data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- postb (str): Position of boundary data in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- pos0 (str): Position of initial data in spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- posc (str): Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- postc (str): Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
- sigmas (float): Standard deviation of the Gaussian noise of sensor data. Default 0
- sigmab (float): Standard deviation of the Gaussian noise of boundary data. Default 0
- sigma0 (float): Standard deviation of the Gaussian noise of initial data. Default 0
Returns
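- dict with the generated data. For train=True the keys are 'sensor', 'usensor', 'boundary', 'uboundary', 'initial', 'uinitial' and 'collocation'; for train=False they are 'xt' and 'u'. Entries for data that was not requested are None.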
def read_data_frame(file, sep=None, header='infer', sheet=0):
    """
    Read a data file and convert it to a JAX array.

    Parameters
    ----------
    file : str or path-like
        File name with extension .csv, .txt, .xls or .xlsx (case-insensitive).
    sep : str, optional
        Field separator for .csv and .txt files. Default ',' for .csv and ' ' for .txt.
    header : int, sequence of int, 'infer' or None
        Row(s) to use as column names; see the pandas.read_csv documentation. Default 'infer'.
    sheet : int
        Sheet passed to ``pandas.read_excel`` as ``sheet_name`` for .xls and .xlsx files. Default 0.

    Returns
    -------
    jax.numpy.ndarray
        float32 array with the values of the table.

    Raises
    ------
    ValueError
        If the file extension is not one of the supported ones.
    """
    #Use the text after the LAST dot so paths such as './data/file.csv' or
    #'run.1.csv' are recognised
    name = str(file)
    ext = name.rsplit('.', 1)[-1].lower() if '.' in name else ''

    #Read data frame
    if ext == 'csv':
        dat = pandas.read_csv(file, sep=',' if sep is None else sep, header=header)
    elif ext == 'txt':
        dat = pandas.read_table(file, sep=' ' if sep is None else sep, header=header)
    elif ext in ('xls', 'xlsx'):
        dat = pandas.read_excel(file, header=header, sheet_name=sheet)
    else:
        raise ValueError(f"Unsupported file extension {ext!r} in {name!r}; expected .csv, .txt, .xls or .xlsx")

    #Convert to JAX data structure
    return jnp.array(dat.to_numpy(), dtype=jnp.float32)
Read a data file and convert to JAX array.
Parameters
- file (str): File name with extension .csv, .txt, .xls or .xlsx
- sep (str): Field separator for .csv and .txt files. Default ',' for .csv and ' ' for .txt
- header (int, Sequence of int, 'infer' or None): Row(s) to use as column names; see the pandas.read_csv documentation. Default 'infer'
- sheet (int): Sheet number for .xls and .xlsx files. Default 0
Returns
- a float32 JAX numpy array with the values of the table