From 1324b1505bcd3028d3dc01b860c47c433f8ca054 Mon Sep 17 00:00:00 2001 From: Edward McDugald Date: Wed, 19 Feb 2025 22:48:49 -0700 Subject: [PATCH 1/4] added grl scripts --- tsunami/datasets.py | 15 + tsunami/make_grl_fig3.py | 422 +++++++++++++++++ tsunami/make_grl_fig4.py | 843 +++++++++++++++++++++++++++++++++ tsunami/quick_plot_combined.py | 92 ++++ 4 files changed, 1372 insertions(+) create mode 100644 tsunami/make_grl_fig3.py create mode 100644 tsunami/make_grl_fig4.py create mode 100644 tsunami/quick_plot_combined.py diff --git a/tsunami/datasets.py b/tsunami/datasets.py index 8965d70..7c5a13c 100644 --- a/tsunami/datasets.py +++ b/tsunami/datasets.py @@ -36,4 +36,19 @@ def SWETsunamiWdiv(fname): return wave_height, latitude, longitude, ocn_floor, divu, mask, time_idx else: return wave_height, latitude, longitude, ocn_floor, divu, mask + + +def SWETsunamiForPlotting2(fname): + fpath = "/Users/emcdugald/sparse_sens_tsunami/Data/tsunami/"+fname + data = sio.loadmat(fpath) + wave_height = data['zt'] + max_wave_height = np.max(np.abs(wave_height)) + wave_height /= max_wave_height + longitude = data['longitude'][0] + latitude = data['latitude'][0] + mask = data['ismask'][0] + times = data['data_times'] + sensors = data['sensor_locs'] + div = data['du'] + return wave_height, latitude, longitude, mask, max_wave_height, times, sensors, div \ No newline at end of file diff --git a/tsunami/make_grl_fig3.py b/tsunami/make_grl_fig3.py new file mode 100644 index 0000000..5f0dc02 --- /dev/null +++ b/tsunami/make_grl_fig3.py @@ -0,0 +1,422 @@ +import multiprocessing +multiprocessing.set_start_method("fork") +import os +import torch +import rasterio +import matplotlib.pyplot as plt +import numpy as np +from rasterio.merge import merge +from datasets import SWETsunamiForPlotting +from scipy.interpolate import griddata +import scipy.io as sio +from mpl_toolkits.axes_grid1 import make_axes_locatable +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 16 +plt.rcParams['xtick.labelsize'] = 16 +plt.rcParams['ytick.labelsize'] = 16 +plt.rcParams['font.weight'] = 'bold' +path_pref = os.getcwd()+"/" + +#activate conda env topos + +## data source ## +### https://topotools.cr.usgs.gov/gmted_viewer/viewer.htm ### +# +# f1 = rasterio.open(path_pref+'geodata/10N090E_20101117_gmted_mea075.tif') +# dta1 = f1.read() +# f2 = rasterio.open(path_pref+'geodata/10N120E_20101117_gmted_mea075.tif') +# dta2 = f2.read() +# f3 = rasterio.open(path_pref+'geodata/10N150E_20101117_gmted_mea075.tif') +# dta3 = f3.read() +# f4 = rasterio.open(path_pref+'geodata/10S090E_20101117_gmted_mea075.tif') +# dta4 = f4.read() +# f5 = rasterio.open(path_pref+'geodata/10S120E_20101117_gmted_mea075.tif') +# dta5 = f5.read() +# f6 = rasterio.open(path_pref+'geodata/10S150E_20101117_gmted_mea075.tif') +# dta6 = f6.read() +# f7 = rasterio.open(path_pref+'geodata/30N090E_20101117_gmted_mea075.tif') +# dta7 = f7.read() +# f8 = rasterio.open(path_pref+'geodata/30N120E_20101117_gmted_mea075.tif') +# dta8 = f8.read() +# f9 = rasterio.open(path_pref+'geodata/30N150E_20101117_gmted_mea075.tif') +# dta9 = f9.read() +# f10 = rasterio.open(path_pref+'geodata/50N090E_20101117_gmted_mea075.tif') +# dta10 = f10.read() +# f11 = rasterio.open(path_pref+'geodata/50N120E_20101117_gmted_mea075.tif') +# dta11 = f11.read() +# f12 = rasterio.open(path_pref+'geodata/50N150E_20101117_gmted_mea075.tif') +# dta12 = f12.read() +# # +# combined_data = merge([f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12]) +# full_data = combined_data[0][0] +# transform = combined_data[1] +# topo_dataset = rasterio.open(path_pref+'geodata/merged.tif', 'w',driver='GTiff', +# height=full_data.shape[0],width=full_data.shape[1],count=1, +# dtype=full_data.dtype,crs='+proj=latlong',transform=transform) +# topo_dataset.write(full_data, 1) +# topo_dataset.close() +topo = rasterio.open(path_pref+'geodata/merged.tif') +topo_data = topo.read() + +import geopandas as gpd +# from shapely.geometry import mapping +# from rasterio import mask as msk +# +# def clip_raster(gdf, img): +# clipped_array, clipped_transform = msk.mask(img, [mapping(gdf.iloc[0].geometry)]) +# clipped_array, clipped_transform = msk.mask(img, [mapping(gdf.iloc[0].geometry)], nodata=(np.amax(clipped_array[0]) + 1)) +# clipped_array[0] = clipped_array[0] + abs(np.amin(clipped_array)) +# value_range = np.amax(clipped_array) + abs(np.amin(clipped_array)) +# return clipped_array, value_range +# +df = gpd.read_file(path_pref+'geodata/ne_10m_admin_0_countries.shp') +# +chi = df.loc[df['ADMIN'] == 'China'] +# clipped_array_chi, clipped_transform_chi = msk.mask(topo, [mapping(chi.iloc[0].geometry)]) +# china_topography, china_value_range = clip_raster(chi, topo) +# # # # # +jap = df.loc[df['ADMIN'] == 'Japan'] +# clipped_array_jap, clipped_transform_jap = msk.mask(topo, [mapping(jap.iloc[0].geometry)]) +# japan_topography, japan_value_range = clip_raster(jap, topo) +# # # # # +rus = df.loc[df['ADMIN'] == 'Russia'] +# clipped_array_rus, clipped_transform_rus = msk.mask(topo, [mapping(rus.iloc[0].geometry)]) +# russia_topography, russia_value_range = clip_raster(rus, topo) +# # # # # +sko = df.loc[df['ADMIN'] == 'South Korea'] +# clipped_array_sko, clipped_transform_sko = msk.mask(topo, [mapping(sko.iloc[0].geometry)]) +# sko_topography, sko_value_range = clip_raster(sko, topo) +# # # # # +nko = df.loc[df['ADMIN'] == 'North Korea'] +# clipped_array_nko, clipped_transform_nko = msk.mask(topo, [mapping(nko.iloc[0].geometry)]) +# nko_topography, nko_value_range = clip_raster(nko, topo) +# # # # # +phi = df.loc[df['ADMIN'] == 'Philippines'] +# clipped_array_phi, clipped_transform_phi = msk.mask(topo, [mapping(phi.iloc[0].geometry)]) +# phi_topography, phi_value_range = clip_raster(phi, topo) +# # # # # # +tai = df.loc[df['ADMIN'] == 'Taiwan'] +# clipped_array_tai, clipped_transform_tai = msk.mask(topo, [mapping(tai.iloc[0].geometry)]) +# tai_topography, tai_value_range = clip_raster(tai, topo) +# # # # # +vie = df.loc[df['ADMIN'] == 'Vietnam'] +# clipped_array_vie, clipped_transform_vie = msk.mask(topo, [mapping(vie.iloc[0].geometry)]) +# vie_topography, vie_value_range = clip_raster(vie, topo) +# # # # # +mon = df.loc[df['ADMIN'] == 'Mongolia'] +# clipped_array_mon, clipped_transform_mon = msk.mask(topo, [mapping(mon.iloc[0].geometry)]) +# mon_topography, mon_value_range = clip_raster(mon, topo) +# mon_topography = mon_topography - np.min(mon_topography) + + +min_lat = -10 +max_lat = 70 +min_lon = 90 +max_lon = 180 +min_lat_plot = -8 +max_lat_plot = 65 +min_lon_plot = 110 +max_lon_plot = 180 +nlat , nlon = np.shape(topo_data[0]) +lat_vals = np.linspace(max_lat,min_lat,nlat) +lon_vals = np.linspace(min_lon,max_lon,nlon) +min_lat_idx = np.where(lat_vals >= min_lat_plot)[0][-1] +max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx]) +lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +dx = (lon_vals_new[1]-lon_vals_new[0])/2. +dy = (lat_vals_new[1]-lat_vals_new[0])/2. +extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] +topo_vals = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] +print("topo_vals_shape:",np.shape(topo_vals)) +topo_vals = topo_vals[::80,::80] +lat_vals = lat_vals_new[::80] +lon_vals = lon_vals_new[::80] +print("new topo_vals_shape:",np.shape(topo_vals)) + + +def coord_idx(s, ver_num): + if ver_num == 34 or ver_num == 0: + if 0 <= s <= 144: + return 0 + elif 145 <= s <= 289: + return 1 + elif 290 <= s <= 434: + return 2 + elif 435 <= s <= 579: + return 3 + elif 580 <= s <= 724: + return 4 + elif 725 <= s <= 869: + return 5 + elif 870 <= s <= 1014: + return 6 + else: + return 7 + else: + if 0 <= s <= 143: + return 0 + elif 144 <= s <= 287: + return 1 + elif 288 <= s <= 431: + return 2 + elif 432 <= s <= 575: + return 3 + elif 576 <= s <= 719: + return 4 + elif 720 <= s <= 863: + return 5 + elif 864 <= s <= 1007: + return 6 + else: + return 7 + + +def training_lons(): + return np.array([136.6180, 139.5560, 139.3290, 138.9350, + 140.9290, 135.7400, 141.5010, 142.3870]) + +def training_lats(): + return np.array([33.0700, 28.8560, 28.9320, 29.3840, + 33.4530, 33.1570, 35.9360, 35.2670]) + +def unseen_lons(): + return np.array([136.6500,138.2000,138.9000, + 139.5000,140.2000,140.5000, + 141.5000,142.5000]) + +def unseen_lats(): + return np.array([33.1000,31.0000,28.1000, + 28.8000,29.1000,31.8000, + 34.2000,36.2000]) + +### UNSEEN EPI DISTANCES ### +# (136.65,33.10) - 2.78 miles - +# (138.20,31.00) - 119.97 miles - +# (138.90,28.10) - 63.11 miles - +# (139.50,28.80) - 5.144 miles - +# (140.20,29.10) - 42.42 miles - +# (140.50,31.80) - 117.04 miles - +# (141.50,34.20) - 61.02 miles - +# (142.50,36.20) - 64.77 miles - + +epi_num = 8 +time = 225 # can be 80, 160, 240, or 60, 120, 180, 240 or 75, 150, 225 +split = 8020 # can be 9505 or 8020 + +if time <= 120: + fname = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_145.mat" + if split == 9505: + ver_num = 34 + save_path = os.getcwd() + "/lightning_logs/combined_34_22/recons/" + else: + ver_num = 0 + save_path = os.getcwd() + "/lightning_logs/combined_0_2/recons/" + slice = int(time*6/5) + (epi_num-1)*145 - 1 + print("selected slice from 0-2 hr model: ", slice) +else: + fname = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_289.mat" + if split == 9505: + ver_num = 22 + save_path = os.getcwd() + "/lightning_logs/combined_34_22/recons/" + else: + ver_num = 2 + save_path = os.getcwd() + "/lightning_logs/combined_0_2/recons/" + slice = int((time-120)*6/5) + (epi_num-1)*144 - 1 + print("selected slice from 2-4 hr model: ",slice) + + +#good for epi 7 +min_lat_plot = 2.5 +max_lat_plot = 57.5 +min_lon_plot = 120 +max_lon_plot = 185 + +#good for epi 3 +# min_lat_plot = 0 +# max_lat_plot = 60 +# min_lon_plot = 115 +# max_lon_plot = 175 + +dx = (lon_vals[1]-lon_vals[0])/2. +dy = (lat_vals[1]-lat_vals[0])/2. +min_lat_idx = np.argsort(np.abs(lat_vals-min_lat_plot))[0] +max_lat_idx = np.argsort(np.abs(lat_vals-max_lat_plot))[0] +min_lon_idx = np.argsort(np.abs(lon_vals-min_lon_plot))[0] +max_lon_idx = np.argsort(np.abs(lon_vals-max_lon_plot))[0] +extent = [lon_vals[min_lon_idx]-dx, lon_vals[max_lon_idx]+dx, + lat_vals[min_lat_idx]-dy, lat_vals[max_lat_idx]+dy] + + +out_path = os.getcwd()+"/lightning_logs/version_"+str(ver_num)+"/" +output_im = torch.load(out_path+'tensor_unseen.pt') + +true_data, latitude, longitude, mask, max_ht, times, sensors, div = SWETsunamiForPlotting(fname) +true_data *= max_ht +output_im *= max_ht + +# NEED INTERPOLATED WAVE_HEIGT, LONGITUDE, LATITUDE, MASK +epi_lons = unseen_lons() +epi_lats = unseen_lats() +epi_idx = coord_idx(slice,ver_num) +tsun_lons_tmp = longitude*(180 / np.pi) +tsun_lats_tmp = latitude*(180 / np.pi) +tsun_lons = tsun_lons_tmp +tsun_lats = tsun_lats_tmp +tsun_lons = tsun_lons[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot) ] +tsun_lats = tsun_lats[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] + +int_Lons, int_Lats = np.meshgrid(lon_vals, lat_vals) + +epi_lon = epi_lons[epi_idx] +epi_lat = epi_lats[epi_idx] +mins = time +prediction = output_im[slice].cpu().detach().numpy() + +true_data_s = true_data[slice][(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +prediction = prediction[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] + +unstruct_coords = np.array([tsun_lons, tsun_lats]).T +interp_true = np.asarray( + griddata(unstruct_coords, true_data_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +interp_pred = np.asarray( + griddata(unstruct_coords, prediction, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) + +interp_true = interp_true[:,:,0] +interp_pred = interp_pred[:,:,0] +interp_abs_err = np.abs(interp_true-interp_pred) + +interp_true[(topo_vals != 0)] = 0.0 +interp_pred[(topo_vals != 0)] = 0.0 +interp_abs_err[(topo_vals != 0)] = 0.0 +min_val = min([np.min(interp_true),np.min(interp_pred),np.min(interp_abs_err)]) +max_val = max([np.max(interp_true),np.max(interp_pred),np.max(interp_abs_err)]) + + +width = (extent[1]-extent[0]) +height = (extent[3]-extent[2]) +aspect = width/height +fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(22.8,6)) +# fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(34.2,9.0)) +from numpy.ma import masked_array +topo_img = masked_array(topo_vals, topo_vals == 0.0) +lons_topo_img = masked_array(int_Lons, topo_vals == 0.0) +lats_topo_img = masked_array(int_Lats, topo_vals == 0.0) +true_tsu_img = masked_array(interp_true, topo_vals != 0.0) +lons_tsu_img = masked_array(int_Lons, topo_vals != 0.0) +lats_tsu_img = masked_array(int_Lats, topo_vals != 0.0) +pred_tsu_img = masked_array(interp_pred, topo_vals != 0.0) +abs_err_img = masked_array(interp_abs_err, topo_vals != 0.0) +shw1 = axs[0].contourf(lons_tsu_img,np.flip(lats_tsu_img),true_tsu_img,levels=50,extent=extent,cmap='bwr') +shw2 = axs[0].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +shw3 = axs[1].contourf(lons_tsu_img,np.flip(lats_tsu_img),pred_tsu_img,levels=50,extent=extent,cmap='bwr') +shw4 = axs[1].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +shw5 = axs[2].contourf(lons_tsu_img,np.flip(lats_tsu_img),abs_err_img,levels=50,extent=extent,cmap='bwr') +shw6 = axs[2].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +shw1.set_clim(min_val,max_val) +shw3.set_clim(min_val,max_val) +shw5.set_clim(min_val,max_val) +import shapely.ops as sops +new_shape = sops.unary_union([el for el in chi['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in rus['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in jap['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in nko['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in sko['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in phi['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in tai['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in vie['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2].plot(*geom.exterior.xy, c='k', lw=1.0) +sensor_longs = sensors[:, 0] +sensor_lats = sensors[:, 1] +axs[0].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=60, marker="^",color='yellow',edgecolor='k') +axs[0].scatter(epi_lon,epi_lat,s=120,marker="x",color='k') + +# Circle Selected Sensors +axs[0].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=100,facecolors='none', edgecolors='k', linewidth=3) +#axs[0].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=100,facecolors='none', edgecolors='k',linewidth=3) + +print("circled sensor 1:",(sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14]) +#print("circled sensor 2:",(sensor_longs * (180 / np.pi))[23],(sensor_longs * (180 / np.pi))[23]) + +axs[0].set_xlim([extent[0], extent[1]]) +axs[0].set_ylim([extent[2],extent[3]]) +axs[1].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=60, marker="^",color='yellow',edgecolor='k') +axs[1].scatter(epi_lon,epi_lat,s=120,marker="x",color='k') +axs[1].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=100,facecolors='none', edgecolors='k', linewidth=3) +#axs[1].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=100,facecolors='none', edgecolors='k',linewidth=3) + + +axs[1].set_xlim([extent[0], extent[1]]) +axs[1].set_ylim([extent[2],extent[3]]) +axs[2].set_xlim([extent[0], extent[1]]) +axs[2].set_ylim([extent[2],extent[3]]) +# axs[0].set_title(r'\textbf{True ($m$)}') +axs[0].set_title(r'\textbf{True (' + r'$\mathbf{m}$' +r'\textbf{)}') +axs[1].set_title(r'\textbf{Predicted}') +axs[2].set_title(r'\textbf{Absolute Error}') +axs[0].set_aspect(aspect) +axs[1].set_aspect(aspect) +axs[2].set_aspect(aspect) +axs[0].set_ylabel(r'\textbf{Latitude ($^\circ$N)}', labelpad=5) +axs[0].set_xlabel(r'\textbf{Longitude ($^\circ$E)}', labelpad=5) +fig.suptitle(r'\textbf{Epicenter: (}' + r'\textbf{'+str(epi_lon) + r'}'+ + r'\textbf{$^\circ$E, }' + r'\textbf{'+str(epi_lat) + r'}'+ + r'\textbf{$^\circ$N), Time: }'+ r'\textbf{'+ str(mins) + r'}' + + r'\textbf{ mins}', y=1.0) + +value1 = 3 +value2 = 5 +fig.colorbar(shw1,ax=axs.ravel().tolist(),fraction=0.047, pad=0.04) +axs[0].grid(color = 'gray', linestyle = '--', linewidth = 0.5) +axs[1].grid(color = 'gray', linestyle = '--', linewidth = 0.5) +axs[2].grid(color = 'gray', linestyle = '--', linewidth = 0.5) +plt.savefig(save_path+"h_recon_epi_{}_time_{}_new.png".format(epi_num,time),dpi=400,bbox_inches='tight') + + diff --git a/tsunami/make_grl_fig4.py b/tsunami/make_grl_fig4.py new file mode 100644 index 0000000..44c5e89 --- /dev/null +++ b/tsunami/make_grl_fig4.py @@ -0,0 +1,843 @@ +import multiprocessing +multiprocessing.set_start_method("fork") +import torch +import matplotlib.pyplot as plt +import numpy as np +from datasets import SWETsunamiForPlotting2 +from scipy.interpolate import LinearNDInterpolator +import scipy.io as sio +import sys +import os +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 16 +plt.rcParams['xtick.labelsize'] = 16 +plt.rcParams['ytick.labelsize'] = 16 + +path = os.getcwd()+"/lihfp_figs/" + + + +### EPICENTERS OF UNSEEN DATA ### +def unseen_lons(): + return np.array([136.6500,138.2000,138.9000, + 139.5000,140.2000,140.5000, + 141.5000,142.5000]) + + +def unseen_lats(): + return np.array([33.1000,31.0000,28.1000, + 28.8000,29.1000,31.8000, + 34.2000,36.2000]) + + +epi_lons = unseen_lons() +epi_lats = unseen_lats() +### ### + + +### SIMULATION DATA ### +fname_2hr = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_145.mat" +fname_4hr = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_289.mat" + +train_split = "8020" #9505 or 8020 +regularization = "unreg" #reg or unreg + +if train_split == "9505" and regularization == "unreg": + ver_num_2hr = 34 + ver_num_4hr = 22 +elif train_split == "9505" and regularization == "reg": + ver_num_2hr = 35 + ver_num_4hr = 25 +elif train_split == "8020" and regularization == "unreg": + ver_num_2hr = 0 + ver_num_4hr = 2 +else: #train_split=8020, regularization=reg + ver_num_2hr = 1 + ver_num_4hr = 3 + +logfile = open(path+"/logs/metrics_{}_{}.out".format(train_split,regularization), 'w') +sys.stdout = logfile + + +type = "unseen" +out_path_2hr = os.getcwd()+"/lightning_logs/version_"+str(ver_num_2hr)+"/" +output_im_2hr = torch.load(out_path_2hr+'tensor'+'_'+str(type)+'.pt').numpy() +out_path_4hr = os.getcwd()+"/lightning_logs/version_"+str(ver_num_4hr)+"/" +output_im_4hr = torch.load(out_path_4hr+'tensor'+'_'+str(type)+'.pt').numpy() +true_data_2hr, latitude_2hr, longitude_2hr, mask_2hr, max_ht_2hr, times_2hr, sensors_2hr, div_2hr = SWETsunamiForPlotting2(fname_2hr) +true_data_2hr *= max_ht_2hr +output_im_2hr *= max_ht_2hr +true_data_4hr, latitude_4hr, longitude_4hr, mask_4hr, max_ht_4hr, times_4hr, sensors_4hr, div_4hr = SWETsunamiForPlotting2(fname_4hr) +true_data_4hr *= max_ht_4hr +output_im_4hr *= max_ht_4hr +tsun_lons = longitude_4hr*(180 / np.pi) +tsun_lats = latitude_4hr*(180 / np.pi) + + + + +### COLLECT SIMULATION DATA INTO ONE FOUR HOUR SET ### +intervals_2hr = [(0,145),(145,290),(290,435),(435,580),(580,725),(725,870),(870,1015),(1015,1160)] +sim1_2hr = true_data_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] +sim2_2hr = true_data_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] +sim3_2hr = true_data_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] +sim4_2hr = true_data_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] +sim5_2hr = true_data_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] +sim6_2hr = true_data_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] +sim7_2hr = true_data_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] +sim8_2hr = true_data_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] +intervals_4hr = [(0,144),(144,288),(288,432),(432,576),(576,720),(720,864),(864,1008),(1008,1152)] +sim1_4hr = true_data_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] +sim2_4hr = true_data_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] +sim3_4hr = true_data_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] +sim4_4hr = true_data_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] +sim5_4hr = true_data_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] +sim6_4hr = true_data_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] +sim7_4hr = true_data_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] +sim8_4hr = true_data_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] +sim1_full = np.concatenate((sim1_2hr,sim1_4hr),axis=0) +sim2_full = np.concatenate((sim2_2hr,sim2_4hr),axis=0) +sim3_full = np.concatenate((sim3_2hr,sim3_4hr),axis=0) +sim4_full = np.concatenate((sim4_2hr,sim4_4hr),axis=0) +sim5_full = np.concatenate((sim5_2hr,sim5_4hr),axis=0) +sim6_full = np.concatenate((sim6_2hr,sim6_4hr),axis=0) +sim7_full = np.concatenate((sim7_2hr,sim7_4hr),axis=0) +sim8_full = np.concatenate((sim8_2hr,sim8_4hr),axis=0) + +### COLLECT SENSEIVER RECONS INTO ONE FOUR HOUR SET ### +sens1_2hr = output_im_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] +sens2_2hr = output_im_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] +sens3_2hr = output_im_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] +sens4_2hr = output_im_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] +sens5_2hr = output_im_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] +sens6_2hr = output_im_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] +sens7_2hr = output_im_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] +sens8_2hr = output_im_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] +sens1_4hr = output_im_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] +sens2_4hr = output_im_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] +sens3_4hr = output_im_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] +sens4_4hr = output_im_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] +sens5_4hr = output_im_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] +sens6_4hr = output_im_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] +sens7_4hr = output_im_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] +sens8_4hr = output_im_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] +sens1_full = np.concatenate((sens1_2hr,sens1_4hr),axis=0) +sens2_full = np.concatenate((sens2_2hr,sens2_4hr),axis=0) +sens3_full = np.concatenate((sens3_2hr,sens3_4hr),axis=0) +sens4_full = np.concatenate((sens4_2hr,sens4_4hr),axis=0) +sens5_full = np.concatenate((sens5_2hr,sens5_4hr),axis=0) +sens6_full = np.concatenate((sens6_2hr,sens6_4hr),axis=0) +sens7_full = np.concatenate((sens7_2hr,sens7_4hr),axis=0) +sens8_full = np.concatenate((sens8_2hr,sens8_4hr),axis=0) + + +### SENSOR INDICES, LOCATIONS AND BATHYMETRY ### +sensors = sensors_2hr*(180/np.pi) +mat_data = sio.loadmat(os.getcwd()+"/Data/tsunami/"+fname_2hr) +sensor_indices = mat_data['sensor_loc_indices'] +sens_lons = tsun_lons[sensor_indices[0]] +sens_lats = tsun_lats[sensor_indices[0]] +bathymetry = mat_data['ocn_floor'] +xy = np.c_[tsun_lons, tsun_lats] +bath = LinearNDInterpolator(xy, bathymetry[0]) + +### RESTRICT DATA TO SENSIBLE WINDOW ### +min_lat_plot = 10 +max_lat_plot = 45 +min_lon_plot = 125 +max_lon_plot = 160 +in_window_indicator = np.zeros_like(sens_lons) +for i in range(len(sens_lons)): + if min_lon_plot <= sens_lons[i] <= max_lon_plot and min_lat_plot <= sens_lats[i] <= max_lat_plot: + in_window_indicator[i] += 1 +in_window_sensor_indices = sensor_indices[0][np.where(in_window_indicator==1.0)] +sens_lons_inner = tsun_lons[in_window_sensor_indices] +sens_lats_inner = tsun_lats[in_window_sensor_indices] +# fig, ax = plt.subplots() +# ax.scatter(sens_lons_inner,sens_lats_inner) +sim1_sens_vals_inner = sim1_full[:,in_window_sensor_indices][:,:,0].T +sim2_sens_vals_inner = sim2_full[:,in_window_sensor_indices][:,:,0].T +sim3_sens_vals_inner = sim3_full[:,in_window_sensor_indices][:,:,0].T +sim4_sens_vals_inner = sim4_full[:,in_window_sensor_indices][:,:,0].T +sim5_sens_vals_inner = sim5_full[:,in_window_sensor_indices][:,:,0].T +sim6_sens_vals_inner = sim6_full[:,in_window_sensor_indices][:,:,0].T +sim7_sens_vals_inner = sim7_full[:,in_window_sensor_indices][:,:,0].T +sim8_sens_vals_inner = sim8_full[:,in_window_sensor_indices][:,:,0].T + + +### MAKE A MATRIX CONSISTING OF DISTANCES BETWEEN PAIRS OF SENSORS ### +sens_distance_arr = np.zeros(shape=(len(sens_lons_inner),len(sens_lons_inner))) +for i in range(len(sens_lons_inner)): + for j in range(len(sens_lons_inner)): + if i= min_lat_plot)[0][-1] +# max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +# min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +# max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +# nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx]) +# lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +# lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +# dx = (lon_vals_new[1]-lon_vals_new[0])/2. +# dy = (lat_vals_new[1]-lat_vals_new[0])/2. +# extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] +# topo_vals = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] +# topo_vals = topo_vals[::80,::80] +# lat_vals = lat_vals_new[::80] +# lon_vals = lon_vals_new[::80] +# dx = (lon_vals[1]-lon_vals[0])/2. +# dy = (lat_vals[1]-lat_vals[0])/2. +# min_lat_idx = np.argsort(np.abs(lat_vals-min_lat_plot))[0] +# max_lat_idx = np.argsort(np.abs(lat_vals-max_lat_plot))[0] +# min_lon_idx = np.argsort(np.abs(lon_vals-min_lon_plot))[0] +# max_lon_idx = np.argsort(np.abs(lon_vals-max_lon_plot))[0] +# extent = [lon_vals[min_lon_idx]-dx, lon_vals[max_lon_idx]+dx, +# lat_vals[min_lat_idx]-dy, lat_vals[max_lat_idx]+dy] +# ocn_floor = bathymetry +# tsun_lons_tmp = tsun_lons +# tsun_lats_tmp = tsun_lats +# tsun_lons_for_bath = tsun_lons_tmp +# tsun_lats_for_bath = tsun_lats_tmp +# tsun_lons_for_bath = tsun_lons[(tsun_lons_tmp >= min_lon_plot) & +# (tsun_lons_tmp <= max_lon_plot) & +# (tsun_lats_tmp >= min_lat_plot) & +# (tsun_lats_tmp <= max_lat_plot) ] +# tsun_lats_for_bath = tsun_lats[(tsun_lons_tmp >= min_lon_plot) & +# (tsun_lons_tmp <= max_lon_plot) & +# (tsun_lats_tmp >= min_lat_plot) & +# (tsun_lats_tmp <= max_lat_plot)] +# int_Lons, int_Lats = np.meshgrid(lon_vals, lat_vals) +# ocn_floor_s = ocn_floor[0][(tsun_lons_tmp >= min_lon_plot) & +# (tsun_lons_tmp <= max_lon_plot) & +# (tsun_lats_tmp >= min_lat_plot) & +# (tsun_lats_tmp <= max_lat_plot)] +# unstruct_coords = np.array([tsun_lons_for_bath, tsun_lats_for_bath]).T +# interp_true = np.asarray( +# griddata(unstruct_coords, ocn_floor_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +# interp_true[(topo_vals != 0)] = 0.0 +# width = (extent[1]-extent[0]) +# height = (extent[3]-extent[2]) +# aspect = width/height +# fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(13,8)) +# from numpy.ma import masked_array +# topo_img = masked_array(topo_vals, topo_vals == 0.0) +# lons_topo_img = masked_array(int_Lons, topo_vals == 0.0) +# lats_topo_img = masked_array(int_Lats, topo_vals == 0.0) +# true_tsu_img = masked_array(interp_true, topo_vals != 0.0) +# lons_tsu_img = masked_array(int_Lons, topo_vals != 0.0) +# lats_tsu_img = masked_array(int_Lats, topo_vals != 0.0) +# shw1 = ax.contourf(lons_tsu_img,np.flip(lats_tsu_img),true_tsu_img,levels=50,extent=extent,cmap='cool') +# shw2 = ax.contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +# ax.scatter(real_lons,real_lats,s=120,c='w',marker='o', label='DART') +# ax.scatter(virtual_lons,virtual_lats,s=120,c='k',marker='x', label='Virtual') +# +# #ax.scatter(virtual_lons[4],virtual_lats[4],s=250,facecolors='none', edgecolors='r',linewidth=5) +# +# ax.text(virtual_lons[0]+1,virtual_lats[0]+1,'1',color='k', weight='bold',fontsize=32) +# ax.text(virtual_lons[1]+1,virtual_lats[1]+1,'2',color='k', weight='bold',fontsize=32) +# ax.text(virtual_lons[2]+1,virtual_lats[2]+1,'3',color='k', weight='bold',fontsize=32) +# ax.text(virtual_lons[3]+1,virtual_lats[3]+1,'4',color='k', weight='bold',fontsize=32) +# ax.text(virtual_lons[4]+1,virtual_lats[4]+1,'5',color='k', weight='bold',fontsize=32) +# ax.text(virtual_lons[5]+1,virtual_lats[5]+1,'6',color='k', weight='bold',fontsize=32) +# +# ax.legend(loc='upper left') +# ax.set_xlim([extent[0], extent[1]]) +# ax.set_ylim([extent[2],extent[3]]) +# #ax.set_aspect(aspect) +# ax.set_ylabel(r'\textbf{Latitude}', labelpad=10) +# ax.set_xlabel(r'\textbf{Longitude}', labelpad=10) +# fig.colorbar(shw1,ax=ax,fraction=.047, pad=0.04) +# ax.grid(color = 'gray', linestyle = '--', linewidth = 0.5) +# bath_save_dir = "/Users/emcdugald/sparse_sens_tsunami/GRL_figs/virtual_waveforms/" +# #plt.savefig(bath_save_dir+"bath_4.png",bbox_inches='tight',dpi=400) +# plt.savefig(bath_save_dir+"bath_4.png",bbox_inches='tight',dpi=400) +# ##################################################################### + + +min_lat_plot = 10 +max_lat_plot = 45 +min_lon_plot = 125 +max_lon_plot = 160 + +### GET LEFT, RIGHT, AND VIRTUAL BATHYMETRY VALUES ### +real_sens_left_bath = np.zeros(6) +for i in range(6): + bath_lon = tsun_lons[in_window_sensor_indices[five_smallest_indices[0][i]]] + bath_lat = tsun_lats[in_window_sensor_indices[five_smallest_indices[0][i]]] + bathpt = bath(bath_lon,bath_lat) + real_sens_left_bath[i] = bathpt + +real_sens_right_bath = np.zeros(6) +for i in range(6): + bath_lon = tsun_lons[in_window_sensor_indices[five_smallest_indices[1][i]]] + bath_lat = tsun_lats[in_window_sensor_indices[five_smallest_indices[1][i]]] + bathpt = bath(bath_lon,bath_lat) + real_sens_right_bath[i] = bathpt + + +virtual_sens_bath = np.zeros(6) +for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + bathpt = bath(vlon, vlat) + virtual_sens_bath[i] = bathpt + +print("Real Sensor Pair Locations: ",sensor_pair_locs) +print("Virtual Sensor Locations: ",virtual_sensor_locs) +print("Left Sens Bathymetry: ",real_sens_left_bath) +print("Right Sens Bathymetry: ",real_sens_right_bath) +print("Virtual Sens Bathymetry: ",virtual_sens_bath) + +############################## +### GET THE REAL WAVEFORMS ### +all_inner_sens_vals = [sim1_sens_vals_inner,sim2_sens_vals_inner,sim3_sens_vals_inner,sim4_sens_vals_inner, + sim5_sens_vals_inner,sim6_sens_vals_inner,sim7_sens_vals_inner,sim8_sens_vals_inner] + +all_true_for_interp = [sim1_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim2_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim3_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim4_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim5_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim6_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim7_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim8_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)]] + +all_senseiver_for_interp = [sens1_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens2_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens3_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens4_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens5_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens6_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens7_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens8_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)]] + +interp_lons = tsun_lons[(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + +interp_lats = tsun_lats[(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + +xy_inner = np.c_[interp_lons, interp_lats] + +sens_mean_err_all_sims = [] +wang_mean_err_all_sims = [] +# sens_mean_at_err_all_sims = [] +# wang_mean_at_err_all_sims = [] +# sens_mean_ma_err_all_sims = [] +# wang_mean_ma_err_all_sims = [] +true_max_amplitudes_full = [] +sens_max_amplitudes_full = [] +wang_max_amplitudes_full = [] +true_arrival_times_full = [] +sens_arrival_times_full = [] +wang_arrival_times_full = [] + +for sim_num in range(8): +#for sim_num in [2,4,6,7]: + display_num = sim_num + 1 + print("########## Simulation {} ##########".format(display_num)) + real_sens_left_waveforms = np.zeros(shape=(6,289)) + for i in range(6): + real_sens_left_waveforms[i] = all_inner_sens_vals[sim_num][five_smallest_indices[0][i]] + + real_sens_right_waveforms = np.zeros(shape=(6,289)) + for i in range(6): + real_sens_right_waveforms[i] = all_inner_sens_vals[sim_num][five_smallest_indices[1][i]] + + ### GET THE ARRIVAL TIMES FOR REAL WAVEFORMS ### + real_sens_left_arrival_times = np.zeros(6) + real_sens_left_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_sens_left_waveforms[i]: + if val < 1e-3: + t += 50.0 + idx += 1 + else: + real_sens_left_arrival_times[i] += t + real_sens_left_arrival_time_indices[i] += idx + break + + real_sens_right_arrival_times = np.zeros(6) + real_sens_right_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_sens_right_waveforms[i]: + if val < 1e-3: + t += 50.0 + idx += 1 + else: + real_sens_right_arrival_times[i] += t + real_sens_right_arrival_time_indices[i] += idx + break + + wang_arrival_times = .5*real_sens_left_arrival_times+.5*real_sens_right_arrival_times + + t_seconds = np.arange(0, 50 * 289, 50) + t_minutes = t_seconds/60 + wang_arrival_time_indices = np.zeros(6) + for i in range(6): + wang_arrival_time_indices[i] += np.argmin(np.abs(wang_arrival_times[i]-t_seconds)) + + interpolated_virtual_sens_waveforms = np.zeros(shape=(6,289)) + stop_plot_indices = np.zeros(6) + for i in range(6): + l_idx = int(real_sens_left_arrival_time_indices[i]) + r_idx = int(real_sens_right_arrival_time_indices[i]) + v_idx = int(wang_arrival_time_indices[i]) + smaller_idx = min(l_idx,r_idx) + res = int(289 - v_idx) + left_wf_shifted = real_sens_left_waveforms[i][smaller_idx:smaller_idx+res] + right_wf_shifted = real_sens_right_waveforms[i][smaller_idx:smaller_idx+res] + interpolated_virtual_sens_waveforms[i][v_idx:v_idx+res] = (((.5*left_wf_shifted)* + ((-real_sens_left_bath[i])**(.25))+ + (.5*right_wf_shifted)* + ((-real_sens_right_bath[i])**(.25))) + /((-virtual_sens_bath[i])**(.25))) + + for i in range(6): + fig, axs = plt.subplots(nrows=3,ncols=1, figsize=(8,8)) + axs[0].plot(t_minutes,real_sens_left_waveforms[i]) + axs[1].plot(t_minutes,real_sens_right_waveforms[i]) + axs[2].plot(t_minutes,interpolated_virtual_sens_waveforms[i]) + axs[2].set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"\textbf{Wave Height (m)}") + axs[0].set_title(r'\textbf{Real Waveform A}') + axs[1].set_title(r'\textbf{Real Waveform B}') + axs[2].set_title(r'\textbf{Virtual Waveform (via LIHFP)}') + axs[0].set_xlim(0, 240) + axs[1].set_xlim(0, 240) + axs[2].set_xlim(0, 240) + fig.suptitle(r'\textbf{Epicenter: (}' + r'\textbf{' + str(round(epi_lons[sim_num], 2)) + r'}' + + r'\textbf{$^\circ$E, }' + r'\textbf{' + str(round(epi_lats[sim_num], 2)) + r'}' + + r'\textbf{$^\circ$N), '+r'\textbf{Virtual Location }'+ r'\textbf{' +str(i+1) + r'}') + plt.tight_layout(pad=1.0) + plt.savefig("/Users/emcdugald/sparse_sens_tsunami/" + "GRL_figs/virtual_waveforms/figs_v4/" + "LIHFP_split_{}_sim_{}_reg_{}_sensnum_{}.png".format( + train_split, display_num,regularization, i + 1), bbox_inches='tight' + ) + plt.close() + + + + real_virtual_sens_waveforms = np.zeros(shape=(6, 289)) + for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + for j in range(289): + tsu = LinearNDInterpolator(xy_inner, all_true_for_interp[sim_num][j,:,0]) + real_virtual_sens_waveforms[i,j] = tsu(vlon,vlat) + + senseiver_virtual_sens_waveforms = np.zeros(shape=(6, 289)) + for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + for j in range(289): + senseiver = LinearNDInterpolator(xy_inner, all_senseiver_for_interp[sim_num][j,:,0]) + senseiver_virtual_sens_waveforms[i,j] = senseiver(vlon,vlat) + + true_arrival_times = np.zeros(6) + true_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_virtual_sens_waveforms[i]: + if val < 1e-3: + t += 50.0 + idx += 1 + else: + true_arrival_times[i] += t + true_arrival_time_indices[i] += idx + break + + true_max_amplitudes = np.zeros(6) + for i in range(6): + true_max_amplitudes[i] = np.max(np.abs(real_virtual_sens_waveforms[i])) + + sens_arrival_times = np.zeros(6) + sens_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in senseiver_virtual_sens_waveforms[i]: + if val < 1e-3: + t += 50.0 + idx += 1 + else: + sens_arrival_times[i] += t + sens_arrival_time_indices[i] += idx + break + + senseiver_max_amplitudes = np.zeros(6) + for i in range(6): + senseiver_max_amplitudes[i] = np.max(np.abs(senseiver_virtual_sens_waveforms[i])) + + wang_max_amplitudes = np.zeros(6) + for i in range(6): + wang_max_amplitudes[i] = np.max(np.abs(interpolated_virtual_sens_waveforms[i])) + + print("Left Sens Arrival Times: ", real_sens_left_arrival_times/60) + print("Right Sens Arrival Times: ", real_sens_right_arrival_times/60) + print("Virtual Arrival Times (True): ", true_arrival_times/60) + print("Virtual Arrival Times (Wang): ", wang_arrival_times/60) + print("Virtual Arrival Times (Sens): ", sens_arrival_times/60) + print("Virtual Max Amp (True): ", true_max_amplitudes) + print("Virtual Max Amp (Wang): ", wang_max_amplitudes) + print("Virtual Max Amp (Sens): ", senseiver_max_amplitudes) + print("Wang AT MAE: ",np.mean(np.abs(true_arrival_times-wang_arrival_times))/60) + print("Senseiver AT MAE: ",np.mean(np.abs(true_arrival_times-sens_arrival_times))/60) + print("Wang MA MAE: ",np.mean(np.abs(true_max_amplitudes-wang_max_amplitudes))) + print("Senseiver MA MAE: ",np.mean(np.abs(true_max_amplitudes-senseiver_max_amplitudes))) + true_max_amplitudes_full.append(true_max_amplitudes) + sens_max_amplitudes_full.append(senseiver_max_amplitudes) + wang_max_amplitudes_full.append(wang_max_amplitudes) + true_arrival_times_full.append(true_arrival_times/60) + sens_arrival_times_full.append(sens_arrival_times/60) + wang_arrival_times_full.append(wang_arrival_times/60) + + senseiver_mean_arr = [] + wang_mean_arr = [] + for i in range(6): + times = (5/6)*np.arange(0, 289, 1) + + fig, axs = plt.subplots(nrows=2,ncols=1, figsize=(9,8)) + + axs[0].plot(times,real_virtual_sens_waveforms[i], c='k') + axs[0].plot(times,senseiver_virtual_sens_waveforms[i], c='g') + axs[0].plot(times,interpolated_virtual_sens_waveforms[i],c='r') + axs[0].legend(['SWE', 'Senseiver', 'LIHFP']) + # txt = r'\textbf{Wave Height at Virtual Sensor (}' + r'\textbf{' + str( + # round(virtual_lons[i],2)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( + # round(virtual_lats[i],2)) + r'}' + r'\textbf{$^\circ$N)}' + txt = r'\textbf{Wave Height at Virtual Sensor }'+ r'\textbf{' + str(i+1) + r'}' + axs[0].set_title(txt) + #axs[0].set_title('Wave Height(m) at ({}$^\circ$E, {}$^\circ$N)'.format(round(virtual_lons[i],2),round(virtual_lats[i],2))) + axs[0].set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"\textbf{Wave Height(m)}") + axs[0].set_xlim(0,240) + + + axs[1].plot(times,np.abs(real_virtual_sens_waveforms[i] - senseiver_virtual_sens_waveforms[i]), c='k') + axs[1].plot(times,np.abs(real_virtual_sens_waveforms[i] - interpolated_virtual_sens_waveforms[i]), c='r') + axs[1].legend([r'\textbf{Senseiver: MAE = }' + r'\textbf{' + str(round(np.mean(np.abs(real_virtual_sens_waveforms[i] - senseiver_virtual_sens_waveforms[i])),2)) + r'}', + r'\textbf{LIHFP: MAE = }' + r'\textbf{' + str(round(np.mean(np.abs(real_virtual_sens_waveforms[i] - interpolated_virtual_sens_waveforms[i])),2)) + r'}']) + axs[1].set_title(r"\textbf{Senseiver vs LIHFP Absolute Error}") + axs[1].set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"$\mathbf{|h-\hat{h}|}$") + axs[1].set_xlim(0, 240) + fig.suptitle(r'\textbf{Epicenter: (}' + r'\textbf{' + str(round(epi_lons[sim_num],3)) + r'}' + + r'\textbf{$^\circ$E, }' + r'\textbf{' + str(round(epi_lats[sim_num],3)) + r'}' + + r'\textbf{$^\circ$N)', y=1.0) + + + + plt.tight_layout() + plt.savefig(path+"/figs/"+ + "split_{}_sim_{}_senslon_{}_senslat_{}_reg_{}_sensnum_{}.png".format( + train_split,display_num, + round(virtual_lons[i],2), + round(virtual_lats[i],2), + regularization,i+1) + ) + plt.close() + print("Senseiver Mean Error for sensor {}:".format(i+1),np.mean(np.abs(real_virtual_sens_waveforms[i]-senseiver_virtual_sens_waveforms[i]))) + print("Wang Mean Error for sensor {}:".format(i+1),np.mean(np.abs(real_virtual_sens_waveforms[i]-interpolated_virtual_sens_waveforms[i]))) + senseiver_mean_arr.append(np.mean(np.abs(real_virtual_sens_waveforms[i]-senseiver_virtual_sens_waveforms[i]))) + wang_mean_arr.append(np.mean(np.abs(real_virtual_sens_waveforms[i]-interpolated_virtual_sens_waveforms[i]))) + print("Mean Senseiver Error for Sim {}:".format(display_num), + np.mean(senseiver_mean_arr)) + print("Mean Wang Error for Sim {}:".format(display_num), + np.mean(wang_mean_arr)) + sens_mean_err_all_sims.append(np.mean(senseiver_mean_arr)) + wang_mean_err_all_sims.append(np.mean(wang_mean_arr)) + +print("Senseiver Mean Error for all Sims: ",np.mean(sens_mean_err_all_sims)) +print("Wang Mean Error for all Sims: ",np.mean(wang_mean_err_all_sims)) +print("Senseiver Error STD for all Sims: ",np.std(sens_mean_err_all_sims)) +print("Wang Error STD for all Sims: ",np.std(wang_mean_err_all_sims)) + + +true_max_amp_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + true_max_amp_mat[i,j] = true_max_amplitudes_full[i][j] + +wang_max_amp_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + wang_max_amp_mat[i,j] = wang_max_amplitudes_full[i][j] + +sens_max_amp_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + sens_max_amp_mat[i,j] = sens_max_amplitudes_full[i][j] + +fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(16,8)) +im1 = axs[0].imshow(true_max_amp_mat) +axs[0].set_title(r'\textbf{True}') +axs[0].set_xlabel(r'\textbf{Virtual Sensors}') +axs[0].set_ylabel(r'\textbf{Simulations}') +im2 = axs[1].imshow(wang_max_amp_mat) +axs[1].set_title(r'\textbf{LIHFP}') +im3 = axs[2].imshow(sens_max_amp_mat) +axs[2].set_title(r'\textbf{Senseiver}') +#fig.colorbar(im1,ax=axs.ravel().tolist(),fraction=0.047, pad=0.04) +pos = axs[2].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im3, cax=cax) + + +fig.suptitle(r'\textbf{Max Amplitudes}',y=.95) +plt.savefig(path+"/figs/"+"split_{}_reg_{}_MaxAmpMat.png".format( + train_split, + regularization) + ) + +true_arr_time_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + true_arr_time_mat[i, j] = true_arrival_times_full[i][j] + +wang_arr_time_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + wang_arr_time_mat[i, j] = wang_arrival_times_full[i][j] + +sens_arr_time_mat = np.zeros(shape=(8,6)) +for i in range(8): + for j in range(6): + sens_arr_time_mat[i, j] = sens_arrival_times_full[i][j] + +fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(16,8)) +im1 = axs[0].imshow(true_arr_time_mat) +axs[0].set_title(r'\textbf{True}') +axs[0].set_xlabel(r'\textbf{Virtual Sensors}') +axs[0].set_ylabel(r'\textbf{Simulations}') +im2 = axs[1].imshow(wang_arr_time_mat) +axs[1].set_title(r'\textbf{LIHFP}') +im3 = axs[2].imshow(sens_arr_time_mat) +axs[2].set_title(r'\textbf{Senseiver}') +#fig.colorbar(im1,ax=axs.ravel().tolist(),fraction=0.047, pad=0.04) + +pos = axs[2].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im3, cax=cax) + +fig.suptitle(r'\textbf{Arrival Times}',y=.95) +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_ArrivalTimeMat.png".format( + train_split, + regularization) + ) + +wang_ma_err = np.abs(true_max_amp_mat-wang_max_amp_mat) +sens_ma_err = np.abs(true_max_amp_mat-sens_max_amp_mat) + + +fig, axs = plt.subplots(nrows=1,ncols=2,figsize=(13,8)) +im1 = axs[0].imshow(wang_ma_err) +axs[0].set_title(r'\textbf{LIHFP MAE: }' + r'\textbf{'+str(round(np.mean(wang_ma_err),2))+r'}'r'\textbf{ (m)}') +axs[0].set_xlabel(r'\textbf{Virtual Sensor Number}') +axs[0].set_ylabel(r'\textbf{Simulation Number}') +im2 = axs[1].imshow(sens_ma_err) +axs[1].set_title(r'\textbf{Senseiver MAE: }'+ r'\textbf{'+str(round(np.mean(sens_ma_err),2))+r'}'+r'\textbf{ (m)}') +fig.suptitle(r'\textbf{Max Amplitude MAE (all simulations)}',y=.95) + +pos = axs[1].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im2, cax=cax) +plt.subplots_adjust(wspace=.1) +#plt.tight_layout() +#fig.colorbar(im1,ax=axs.ravel().tolist(),fraction=0.05, pad=0.04) +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_MaxAmpMatErr.png".format( + train_split, + regularization),bbox_inches='tight') + + + +wang_at_err = np.abs(true_arr_time_mat-wang_arr_time_mat) +sens_at_err = np.abs(true_arr_time_mat-sens_arr_time_mat) + +fig, axs = plt.subplots(nrows=1,ncols=2,figsize=(13,8)) +im1 = axs[0].imshow(wang_at_err) +axs[0].set_title(r'\textbf{LIHFP MAE: }' + r'\textbf{'+str(round(np.mean(wang_at_err),2))+r'}'r'\textbf{ (mins)}') +axs[0].set_xlabel(r'\textbf{Virtual Sensor Number}') +axs[0].set_ylabel(r'\textbf{Simulation Number}') +im2 = axs[1].imshow(sens_at_err) +axs[1].set_title(r'\textbf{Senseiver MAE: }' + r'\textbf{'+str(round(np.mean(sens_at_err),2))+r'}'r'\textbf{ (mins)}') +fig.suptitle(r'\textbf{Arrival Time MAE (all simulations)}',y=.95) + +pos = axs[1].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im2, cax=cax) +plt.subplots_adjust(wspace=.1) +#plt.tight_layout() +#fig.colorbar(im1,ax=axs.ravel().tolist(),fraction=0.05, pad=0.04) +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_ArrivalTimeMatErr.png".format( + train_split,regularization),bbox_inches='tight') + + +#make wang max amp colormap +#make sens max amp colormap +#make true at colormap +#make wang at colormap +#make sens at colormap +#make sens ma err cmap +#make wang ma err cmap +#make sens at err cmap +#make wang at err cmap + +true_max_amplitudes_full = np.array(true_max_amplitudes_full).flatten() +sens_max_amplitudes_full = np.array(sens_max_amplitudes_full).flatten() +wang_max_amplitudes_full = np.array(wang_max_amplitudes_full).flatten() +true_arrival_times_full = np.array(true_arrival_times_full).flatten() +sens_arrival_times_full = np.array(sens_arrival_times_full).flatten() +wang_arrival_times_full = np.array(wang_arrival_times_full).flatten() + +print("Senseiver Max Amp MAE for all sims: ",np.mean(np.abs(true_max_amplitudes_full-sens_max_amplitudes_full))) +print("Wang Max Amp MAE for all sims: ",np.mean(np.abs(true_max_amplitudes_full-wang_max_amplitudes_full))) +print("Senseiver Arrival Time MAE for all sims: ",np.mean(np.abs(true_arrival_times_full-sens_arrival_times_full))) +print("Wang Arrival Time MAE for all sims: ",np.mean(np.abs(true_arrival_times_full-wang_arrival_times_full))) +logfile.close() + +fig, ax = plt.subplots(figsize = (7,6)) +ax.plot(true_max_amplitudes_full,color='k') +ax.plot(sens_max_amplitudes_full,color='r') +ax.plot(wang_max_amplitudes_full,color='g') +ax.legend(['True MA', 'Sens MA', 'LIHFP MA']) +plt.tight_layout() +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_MaxAmps.png".format( + train_split, + regularization) + ) + +fig, ax = plt.subplots(figsize = (7,6)) +ax.plot(np.abs(true_max_amplitudes_full-wang_max_amplitudes_full),color='k') +ax.plot(np.abs(true_max_amplitudes_full-sens_max_amplitudes_full),color='r') +ax.legend(['Wang MA Err', 'Sens MA Err']) +plt.tight_layout() +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_MaxAmpsErrs.png".format( + train_split, + regularization) + ) + + +fig, ax = plt.subplots(figsize = (7,6)) +ax.plot(true_arrival_times_full,color='k') +ax.plot(sens_arrival_times_full,color='r') +ax.plot(wang_arrival_times_full,color='g') +ax.legend(['True AT', 'Sens AT', 'LIHFP AT']) +plt.tight_layout() +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_ArrivalTimes.png".format( + train_split, + regularization) + ) + +fig, ax = plt.subplots(figsize = (7,6)) +ax.plot(np.abs(true_arrival_times_full-wang_arrival_times_full),color='k') +ax.plot(np.abs(true_arrival_times_full-sens_arrival_times_full),color='r') +ax.legend(['Wang AT Err', 'Sens AT Err']) +plt.tight_layout() +plt.savefig(path+"/figs/"+ + "split_{}_reg_{}_ArrivalTimesErrs.png".format( + train_split, + regularization) + ) \ No newline at end of file diff --git a/tsunami/quick_plot_combined.py b/tsunami/quick_plot_combined.py new file mode 100644 index 0000000..34e3637 --- /dev/null +++ b/tsunami/quick_plot_combined.py @@ -0,0 +1,92 @@ +from glob import glob as gb + +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + +from s_parser import parse_args +from dataloaders import senseiver_dataloader +from network_light import Senseiver + +from combined_plot import plot_all_ts_tsu_japan_unstruct + +import multiprocessing +import os + +num_gpus = torch.cuda.device_count() +print("Num Devices:", num_gpus) +print("GPU Info:", [torch.cuda.get_device_name(i) for i in range(num_gpus)]) + +# arg parser +data_config, encoder_config, decoder_config = parse_args() + +multiprocessing.set_start_method("fork") + +# load the simulation data and create a dataloader +dataloader1 = senseiver_dataloader(data_config, num_workers=4) + +data_config2 = data_config +data_config2['data_key'] = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_289.mat" +dataloader2 = senseiver_dataloader(data_config2, num_workers=4) + + +model_num1 = encoder_config['load_model_num'] +if model_num1 == 34: + model_num2 = 22 +else: + model_num2 = 2 + + +print(f'Loading {model_num1} ...') + +model_loc1 = gb(f"lightning_logs/version_{model_num1}/checkpoints/*.ckpt")[0] +# Use the below commented code if using on HPC +# model = Senseiver.load_from_checkpoint(model_loc, +# **encoder_config, +# **decoder_config, +# **data_config) +model1 = Senseiver.load_from_checkpoint(model_loc1, map_location=torch.device('cpu'), + **encoder_config, + **decoder_config, + **data_config) + +print(f'Loading {model_num2} ...') + +model_loc2 = gb(f"lightning_logs/version_{model_num2}/checkpoints/*.ckpt")[0] +# Use the below commented code if using on HPC +# model = Senseiver.load_from_checkpoint(model_loc, +# **encoder_config, +# **decoder_config, +# **data_config) +model2 = Senseiver.load_from_checkpoint(model_loc2, map_location=torch.device('cpu'), + **encoder_config, + **decoder_config, + **data_config2) + + +name = 'tensor.pt' +unseen_flag = data_config['unseen_flag'] +mat_data1 = data_config['data_key'] + +with torch.no_grad(): + if data_config['unseen_flag'] == True: + output_im1 = torch.load(os.getcwd()+"/lightning_logs/version_"+str(model_num1)+'/tensor'+'_unseen'+'.pt') + output_im2 = torch.load(os.getcwd()+"/lightning_logs/version_"+str(model_num2)+'/tensor'+'_unseen'+'.pt') + else: + output_im1 = torch.load(os.getcwd() + "/lightning_logs/version_" + str(model_num1) + '/tensor' + '_training' + '.pt') + output_im2 = torch.load(os.getcwd() + "/lightning_logs/version_" + str(model_num2) + '/tensor' + '_training' + '.pt') + +if encoder_config['load_model_num'] == 34: + path = "/Users/emcdugald/sparse_sens_tsunami/lightning_logs/combined_34_22/" +else: + path = "/Users/emcdugald/sparse_sens_tsunami/lightning_logs/combined_0_2/" + +if unseen_flag: + name = 'tensor_unseen.pt' + type = 'unseen' +else: + name = 'tensor_training.pt' + type = 'training' +plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2, output_im1, output_im2, mat_data1, type, path) + + From 2c0d864a3e76b327202e3fe4dc21b07ed1cd1d4c Mon Sep 17 00:00:00 2001 From: Edward McDugald Date: Wed, 19 Feb 2025 22:49:35 -0700 Subject: [PATCH 2/4] added grl figs --- tsunami/combined_plot.py | 550 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 550 insertions(+) create mode 100644 tsunami/combined_plot.py diff --git a/tsunami/combined_plot.py b/tsunami/combined_plot.py new file mode 100644 index 0000000..8bd5e69 --- /dev/null +++ b/tsunami/combined_plot.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import torch +import matplotlib.pyplot as plt +import os +import scipy.io as sio +import numpy as np +from matplotlib.ticker import FormatStrFormatter +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"]=16 +plt.rcParams['xtick.labelsize']=16 +plt.rcParams['ytick.labelsize']=16 +plt.rcParams['text.latex.preamble'] = r'\usepackage{bm}' + +def training_lons(): + return np.array([136.6180, 139.5560, 139.3290, 138.9350, 140.9290, 135.7400, 141.5010, 142.3870]) + +def training_lats(): + return np.array([33.0700, 28.8560, 28.9320, 29.3840, 33.4530, 33.1570, 35.9360, 35.2670]) + +def unseen_lons(): + return np.array([136.6500, 140.2000, 138.9000, 139.5000]) + +def unseen_lats(): + return np.array([33.1000, 29.1000, 28.1000, 28.8000]) + + +def unseen_lons_new(): + return np.array([136.6500,138.2000,138.9000, + 139.5000,140.2000,140.5000, + 141.5000,142.5000]) + +def unseen_lats_new(): + return np.array([33.1000,31.0000,28.1000, + 28.8000,29.1000,31.8000, + 34.2000,36.2000]) + + + +def plot_err_for_each_epi(means, times, dta_type, path, plot_pts): + time_idxs = np.insert(times,0,0) + pts = np.sort(plot_pts) + + if dta_type=='training': + lons = training_lons() + lats = training_lats() + else: + lons = unseen_lons_new() + lats = unseen_lats_new() + + j=0 + for i in range(1, len(time_idxs)): + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7.2,6)) + emean = means[int(time_idxs[i - 1]):int(time_idxs[i])] + epi_lon = lons[j] + epi_lat = lats[j] + if dta_type == 'training': + start_plt_pts = (i - 1) * (np.shape(emean)[0]) + end_plt_pts = (i) * (np.shape(emean)[0]) + plot_points = pts[(pts > start_plt_pts) & (pts < end_plt_pts)] - start_plt_pts + ax.plot(emean) + ax.scatter(plot_points*(5./6.), torch.tensor(emean)[plot_points], c='r', marker='o', s=5) + ax.title.set_text(f'Average Error = {torch.tensor(emean).mean():0.3}') + #ax.set(xlabel=r"\textbf{Time Steps}", ylabel="$\mathbf{|true-pred|/\max(|true|)}$") + + ax.set_xlabel(r"\textbf{Time Steps}") + ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', + horizontalalignment='right', verticalalignment='center',fontsize=16) + + + ax.legend(['all data', 'training data', '.10 threshold', '.05 threshold']) + else: + plot_pts = np.arange(0,len(emean),1)*(5./6.) + plot_pts = plot_pts + ax.plot(plot_pts,emean) + + arr = np.array(emean) + condition = arr < 0.10 + for i in range(len(arr)): + if condition[i] and np.all(condition[i:]): # Check if all subsequent elements are also less than 0.10 + lt_idx = i + break + else: + lt_idx = None # If no such index exists + + text = r'\textbf{Average Error = }' + r'\textbf{'+str(round(torch.tensor(emean).mean().item(),3)) + r'}' + ax.title.set_text(text) + # ax.title.set_text(f'Average Error = {torch.tensor(emean).mean():0.3}') + ax.scatter(plot_pts[90],emean[90],s=75,marker="x",color='r') + ax.scatter(plot_pts[180], emean[180], s=75, marker="x", color='r') + ax.scatter(plot_pts[270], emean[270], s=75, marker="x", color='r') + ax.scatter(plot_pts[lt_idx], emean[lt_idx], s=75, marker="x", color='g') + ax.set_xlim(0,240) + ax.set_ylim(0,np.max(emean)) + # ax.set(xlabel="Time (minutes)", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") + # ax.legend(['_', + # r'Error at $75$ mins: ${}$'.format(round(emean[90],3)), + # r'Error at $150$ mins: ${}$'.format(round(emean[180],3)), + # r'Error at 225 mins: ${}$'.format(round(emean[270],3)), + # r'$10\%$ threshold time: ${}$ mins'.format(round(plot_pts[lt_idx],1))]) + #ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") + ax.set_xlabel(r"\textbf{Time (minutes)}") + ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', + horizontalalignment='right', verticalalignment='center',fontsize=16) + + ax.legend(['_', + r'\textbf{Error at $75$ mins: ' + r'\textbf{' + str(round(emean[90],3)) + r'}', + r'\textbf{Error at $150$ mins: ' + r'\textbf{' + str(round(emean[180], 3)) + r'}', + r'\textbf{Error at $225$ mins: ' + r'\textbf{' + str(round(emean[270], 3)) + r'}', + r'\textbf{$10\%$ threshold time: }' + r'\textbf{' +str(round(plot_pts[lt_idx], 1)) + r'}' + r'\textbf{ mins}']) + j += 1 + + + plt.tight_layout() + if path: + plt.savefig(path + '/all_ts_{}'.format(dta_type)+"_epi_{}_{}".format(epi_lon,epi_lat)+"_new.png",dpi=400,bbox_inches='tight') + plt.close() + + + +def plot_timeseries_for_each_epi_unstruct(true,pred,times,sens_idxs,sens_locs,dta_type,path): + time_idxs = np.insert(times, 0, 0) + + if dta_type == 'training': + lons = training_lons() + lats = training_lats() + else: + lons = unseen_lons_new() + lats = unseen_lats_new() + + + for i in range(1, len(time_idxs)): + j = 0 + for idx in sens_idxs[0]: + true_ts = true[int(time_idxs[i - 1]):int(time_idxs[i]), idx] + pred_ts = pred[int(time_idxs[i - 1]):int(time_idxs[i]), idx] + if np.abs(true_ts).max() >= 5e-2: + times = np.arange(0,len(true_ts),1)*(5/6) + times = times + fig, ax = plt.subplots(figsize=(7.2,6)) + ax.plot(times,true_ts, c='g') + ax.plot(times,pred_ts, c='k') + + lon_deg = sens_locs[j][0] * 180 / np.pi + lat_deg = sens_locs[j][1] * 180 / np.pi + ax.set_xlim(0, 240) + + # ax.title.set_text('Wave Height Time Series at sensor ({}$^\circ$E, {}$^\circ$N)'.format(round(lon_deg,1), round(lat_deg,1))) + # ax.set(xlabel="Time (minutes)", ylabel="Wave Height (m)") + # ax.legend(['True Height', 'Predicted Height']) + txt = r'\textbf{Wave Height at Sensor (}' + r'\textbf{' + str( + round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( + round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' + # txt = r'\textbf{Wave Height Time Series at Sensor (}' + r'\textbf{' + str(round(lon_deg,1))+r'}' + r'\textbf{$^\circ$E}, ' + # + r'\textbf{'+str(round(lat_deg,1))+r'}'+r'\textbf{$^\circ$N)}' + ax.title.set_text(txt) + ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"\textbf{Wave Height (m)}") + ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}']) + + plt.tight_layout() + + plt.savefig(path + 'timeseries/{}'.format(dta_type) + + "/sens_{}_{}_epi_{}_{}".format(round(lon_deg,3),round(lat_deg,3),round(lons[i-1],3),round(lats[i-1],3)) + "_new.png",dpi=400,bbox_inches='tight') + plt.close() + + j += 1 + + + + +def plot_div_err_for_each_epi(means, swemeans, times, dta_type, path): + for k in range(len(times)): + times[k] -= (k+1) + time_idxs = np.insert(times,0,0) + + if dta_type=='training': + lons = training_lons() + lats = training_lats() + else: + lons = unseen_lons_new() + lats = unseen_lats_new() + + j=0 + for i in range(1, len(time_idxs)): + fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(7.2,6)) + emean = means[int(time_idxs[i - 1])+8:int(time_idxs[i])] + swe_emean = swemeans[int(time_idxs[i - 1])+8:int(time_idxs[i])] + times = np.arange(0,len(emean),1)*(5/6) + times = times + epi_lon = lons[j] + epi_lat = lats[j] + ax.plot(times, emean, c='g') + ax.plot(times, swe_emean, c='k') + # ax.title.set_text(f'Average $h_t$ error = {torch.tensor(emean).mean():0.3}') + # ax.set(xlabel="Time (minutes)", ylabel=r"$\frac{|h_t-\hat{h_t}|}{\max(|h|)}$") + # ax.legend(['Senseiver Error', 'SWE Error']) + txt = r'\textbf{Average }' +r'$\mathbf{h_t}$' + r'\textbf{ Error = }' + r'\textbf{'+str(round(torch.tensor(emean).mean().item(),3)) + r'}' + ax.title.set_text(txt) + #ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h|)}}$") + ax.set_xlabel(r"\textbf{Time (minutes)}") + ax.set_ylabel(r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h|)}}$", + rotation='horizontal', horizontalalignment='right', + verticalalignment='center',fontsize=16) + ax.legend([r'\textbf{Senseiver Error}', r'\textbf{SWE Error}']) + ax.set_xlim(0,240) + + j += 1 + plt.tight_layout() + if path: + plt.savefig(path + '/all_ts_div_err_{}'.format(dta_type)+"_epi_{}_{}".format(epi_lon,epi_lat)+"_new.png",dpi=400,bbox_inches='tight') + plt.close() + +def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname, data_type,path=None): + with (torch.no_grad()): + + true2hr = dataloader1.dataset.data.to('cpu') + true4hr = dataloader2.dataset.data.to('cpu') + pred2hr = pred1.to('cpu') + pred4hr = pred2.to('cpu') + mask = dataloader1.dataset.mask + times2hr = true2hr.size()[0] + times4hr = true4hr.size()[0] + full_mask_2hr = mask.repeat([times2hr,1]) + full_mask_4hr = mask.repeat([times4hr,1]) + true_masked_2hr = torch.where((full_mask_2hr == 0), true2hr[..., 0], 0) + pred_masked_2hr = torch.where((full_mask_2hr == 0), pred2hr[...,0], 0) + true_masked_4hr = torch.where((full_mask_4hr == 0), true4hr[..., 0], 0) + pred_masked_4hr = torch.where((full_mask_4hr == 0), pred4hr[..., 0], 0) + intervals_2hr = [(0, 145), (145, 290), (290, 435), (435, 580), (580, 725), (725, 870), (870, 1015), + (1015, 1160)] + sim1_2hr = true_masked_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] + sim2_2hr = true_masked_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] + sim3_2hr = true_masked_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] + sim4_2hr = true_masked_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] + sim5_2hr = true_masked_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] + sim6_2hr = true_masked_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] + sim7_2hr = true_masked_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] + sim8_2hr = true_masked_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + intervals_4hr = [(0, 144), (144, 288), (288, 432), (432, 576), (576, 720), (720, 864), (864, 1008), + (1008, 1152)] + sim1_4hr = true_masked_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] + sim2_4hr = true_masked_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] + sim3_4hr = true_masked_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] + sim4_4hr = true_masked_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] + sim5_4hr = true_masked_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] + sim6_4hr = true_masked_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] + sim7_4hr = true_masked_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] + sim8_4hr = true_masked_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim1_full = np.concatenate((sim1_2hr, sim1_4hr), axis=0) + sim2_full = np.concatenate((sim2_2hr, sim2_4hr), axis=0) + sim3_full = np.concatenate((sim3_2hr, sim3_4hr), axis=0) + sim4_full = np.concatenate((sim4_2hr, sim4_4hr), axis=0) + sim5_full = np.concatenate((sim5_2hr, sim5_4hr), axis=0) + sim6_full = np.concatenate((sim6_2hr, sim6_4hr), axis=0) + sim7_full = np.concatenate((sim7_2hr, sim7_4hr), axis=0) + sim8_full = np.concatenate((sim8_2hr, sim8_4hr), axis=0) + true_masked = np.concatenate((sim1_full,sim2_full,sim3_full,sim4_full, + sim5_full,sim6_full,sim7_full,sim8_full),axis=0) + + sens1_2hr = pred_masked_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] + sens2_2hr = pred_masked_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] + sens3_2hr = pred_masked_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] + sens4_2hr = pred_masked_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] + sens5_2hr = pred_masked_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] + sens6_2hr = pred_masked_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] + sens7_2hr = pred_masked_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] + sens8_2hr = pred_masked_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sens1_4hr = pred_masked_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] + sens2_4hr = pred_masked_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] + sens3_4hr = pred_masked_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] + sens4_4hr = pred_masked_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] + sens5_4hr = pred_masked_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] + sens6_4hr = pred_masked_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] + sens7_4hr = pred_masked_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] + sens8_4hr = pred_masked_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sens1_full = np.concatenate((sens1_2hr, sens1_4hr), axis=0) + sens2_full = np.concatenate((sens2_2hr, sens2_4hr), axis=0) + sens3_full = np.concatenate((sens3_2hr, sens3_4hr), axis=0) + sens4_full = np.concatenate((sens4_2hr, sens4_4hr), axis=0) + sens5_full = np.concatenate((sens5_2hr, sens5_4hr), axis=0) + sens6_full = np.concatenate((sens6_2hr, sens6_4hr), axis=0) + sens7_full = np.concatenate((sens7_2hr, sens7_4hr), axis=0) + sens8_full = np.concatenate((sens8_2hr, sens8_4hr), axis=0) + pred_masked = np.concatenate((sens1_full,sens2_full,sens3_full,sens4_full, + sens5_full,sens6_full,sens7_full,sens8_full),axis=0) + + + maxes = [np.abs(true_masked[i, :]).max() for i in range(len(true_masked))] + abs_errs = [np.abs(true_masked[i, :] - pred_masked[i, :]) for i in + range(len(true_masked))] + emax = [(abs_errs[i] / maxes[i]).max().item() for i in range(len(abs_errs))] + ratios = [(abs_errs[i] / maxes[i]) for i in range(len(abs_errs))] + emean = [ratios[i].flatten()[np.where(np.abs(true_masked[i].flatten()) > 1e-4)].mean().item() for i in range(len(abs_errs))] + + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14,6)) + plot_points_2hr = dataloader1.dataset.train_ind + plot_points_4hr = dataloader2.dataset.train_ind + plot_points_4hr = plot_points_4hr + 1160 + plot_points = torch.cat((plot_points_2hr,plot_points_4hr)) + + ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) + frame_ids_2hr = dataloader1.dataset.time_idx.astype(int) + frame_ids_4hr = dataloader2.dataset.time_idx.astype(int) + frame_ids = frame_ids_2hr + frame_ids_4hr - 289 + + if data_type == 'training': + ax.plot(emean) + ax.scatter(plot_points, torch.tensor(emean)[plot_points.long()], c='r', marker='o', s=5) + + text = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' + ax.title.set_text(text) + #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") + ax.set_xlabel(r"\textbf{Frame Index}") + ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', + horizontalalignment='right', verticalalignment='center',fontsoze=16) + + + ax.vlines(frame_ids, 0, np.max(emean), color='g') + ax.legend([r'\textbf{all data}', r'\textbf{training data}',r'\textbf{epi split}']) + # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).nanmean():0.3}') + # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") + # ax.vlines(frame_ids, 0, np.max(emean), color='g') + # ax.legend(['all data', 'training data', 'epi split']) + + else: + ax.plot(emean) + text = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' + ax.title.set_text(text) + #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") + + ax.set_xlabel(r"\textbf{Frame Index}") + ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', + horizontalalignment='right', verticalalignment='center',fontsize=16) + + + ax.vlines(frame_ids, 0, np.max(emean), color='g') + ax.legend([r'\textbf{all data}', r'\textbf{epi split}']) + # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).nanmean():0.3}') + # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") + # ax.vlines(frame_ids, 0, np.max(emean), color='g') + # ax.legend(['all data', 'epi split']) + + plt.tight_layout() + if path: + plt.savefig(path + '/all_ts_{}'.format(data_type)+"_new.png") + plt.close() + + print(f'The mean mean err is {torch.tensor(emean).nanmean():0.3}') + print(f'The mean max err is {torch.tensor(emax).nanmean():0.3}') + + sim_times = dataloader1.dataset.time_idx + dataloader2.dataset.time_idx + plot_err_for_each_epi(emean, sim_times, data_type, path, plot_points) + + mat_data = sio.loadmat(os.getcwd()+"/Data/tsunami/"+matname) + sensor_loc_indices = mat_data['sensor_loc_indices'] + sensor_locs = mat_data['sensor_locs'] + + if not os.path.exists(path + 'timeseries/{}'.format(data_type)): + if not os.path.exists(path + 'timeseries/'): + os.mkdir(path + 'timeseries/') + os.mkdir(path + 'timeseries/{}'.format(data_type)) + + j = 0 + for idx in sensor_loc_indices[0]: + true_ts = true_masked[:,idx] + pred_ts = pred_masked[:,idx] + + if np.abs(true_ts).max() >= 1e-1: + fig, ax = plt.subplots(figsize=(14,6)) + ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) + + ax.plot(true_ts,c='g') + ax.plot(pred_ts,c='k') + ax.plot(np.abs(true_ts - pred_ts), c='r') + + lon_deg = sensor_locs[j][1] * 180 / np.pi - 180 + lat_deg = sensor_locs[j][0] * 180 / np.pi + txt = r'\textbf{Wave Height at Sensor (}' + r'\textbf{' + str( + round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( + round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' + ax.title.set_text(txt) + ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"\textbf{Normalized Wave Height (m)}") + ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}', r'\textbf{Absolute Error}']) + + # ax.title.set_text('Wave Height Time Series at {}, {}'.format(round(lon_deg), round(lat_deg))) + # ax.set(xlabel="Frame Index", ylabel="Normalized Wave Height (m)") + # ax.legend(['True Height', 'Predicted Height', 'Absolute Error']) + + plt.tight_layout() + + plt.savefig(path + 'timeseries/{}'.format(data_type)+"/sens_{}_{}_all_epis".format(round(lon_deg),round(lat_deg))+"_new.png") + plt.close() + + j += 1 + + plot_timeseries_for_each_epi_unstruct(true_masked, pred_masked, sim_times, sensor_loc_indices, sensor_locs, data_type, path) + + ### PHYSICAL CONSISTENCY PLOTS ### + n_sims = len(dataloader1.dataset.time_idx) + train_ind_start = np.append(0, dataloader1.dataset.time_idx[0:n_sims - 1]) + for i in range(8): + train_ind_start[i] += 144*i + combined_times = dataloader1.dataset.time_idx + dataloader2.dataset.time_idx + train_inds = torch.concatenate([torch.range(train_ind_start[i], combined_times[i] - 2) for i in range(8)]) + curr_inds = torch.add(train_inds, 1).long() + prev_inds = train_inds.long() + + truediv_2hr = dataloader1.dataset.divu.to('cpu') + truediv_4hr = dataloader2.dataset.divu.to('cpu') + sim1div_2hr = truediv_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] + sim2div_2hr = truediv_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] + sim3div_2hr = truediv_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] + sim4div_2hr = truediv_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] + sim5div_2hr = truediv_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] + sim6div_2hr = truediv_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] + sim7div_2hr = truediv_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] + sim8div_2hr = truediv_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sim1div_4hr = truediv_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] + sim2div_4hr = truediv_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] + sim3div_4hr = truediv_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] + sim4div_4hr = truediv_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] + sim5div_4hr = truediv_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] + sim6div_4hr = truediv_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] + sim7div_4hr = truediv_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] + sim8div_4hr = truediv_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim1div_full = np.concatenate((sim1div_2hr, sim1div_4hr), axis=0) + sim2div_full = np.concatenate((sim2div_2hr, sim2div_4hr), axis=0) + sim3div_full = np.concatenate((sim3div_2hr, sim3div_4hr), axis=0) + sim4div_full = np.concatenate((sim4div_2hr, sim4div_4hr), axis=0) + sim5div_full = np.concatenate((sim5div_2hr, sim5div_4hr), axis=0) + sim6div_full = np.concatenate((sim6div_2hr, sim6div_4hr), axis=0) + sim7div_full = np.concatenate((sim7div_2hr, sim7div_4hr), axis=0) + sim8div_full = np.concatenate((sim8div_2hr, sim8div_4hr), axis=0) + true_div = np.concatenate((sim1div_full, sim2div_full, sim3div_full, sim4div_full, + sim5div_full, sim6div_full, sim7div_full, sim8div_full), axis=0) + + true2hr = dataloader1.dataset.data.to('cpu') + true4hr = dataloader2.dataset.data.to('cpu') + pred2hr = pred1.to('cpu') + pred4hr = pred2.to('cpu') + sim1_2hr = true2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] + sim2_2hr = true2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] + sim3_2hr = true2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] + sim4_2hr = true2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] + sim5_2hr = true2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] + sim6_2hr = true2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] + sim7_2hr = true2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] + sim8_2hr = true2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sim1_4hr = true4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] + sim2_4hr = true4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] + sim3_4hr = true4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] + sim4_4hr = true4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] + sim5_4hr = true4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] + sim6_4hr = true4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] + sim7_4hr = true4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] + sim8_4hr = true4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim1_full = np.concatenate((sim1_2hr, sim1_4hr), axis=0) + sim2_full = np.concatenate((sim2_2hr, sim2_4hr), axis=0) + sim3_full = np.concatenate((sim3_2hr, sim3_4hr), axis=0) + sim4_full = np.concatenate((sim4_2hr, sim4_4hr), axis=0) + sim5_full = np.concatenate((sim5_2hr, sim5_4hr), axis=0) + sim6_full = np.concatenate((sim6_2hr, sim6_4hr), axis=0) + sim7_full = np.concatenate((sim7_2hr, sim7_4hr), axis=0) + sim8_full = np.concatenate((sim8_2hr, sim8_4hr), axis=0) + true = np.concatenate((sim1_full, sim2_full, sim3_full, sim4_full, + sim5_full, sim6_full, sim7_full, sim8_full), axis=0) + sens1_2hr = pred2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] + sens2_2hr = pred2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] + sens3_2hr = pred2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] + sens4_2hr = pred2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] + sens5_2hr = pred2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] + sens6_2hr = pred2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] + sens7_2hr = pred2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] + sens8_2hr = pred2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sens1_4hr = pred4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] + sens2_4hr = pred4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] + sens3_4hr = pred4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] + sens4_4hr = pred4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] + sens5_4hr = pred4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] + sens6_4hr = pred4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] + sens7_4hr = pred4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] + sens8_4hr = pred4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sens1_full = np.concatenate((sens1_2hr, sens1_4hr), axis=0) + sens2_full = np.concatenate((sens2_2hr, sens2_4hr), axis=0) + sens3_full = np.concatenate((sens3_2hr, sens3_4hr), axis=0) + sens4_full = np.concatenate((sens4_2hr, sens4_4hr), axis=0) + sens5_full = np.concatenate((sens5_2hr, sens5_4hr), axis=0) + sens6_full = np.concatenate((sens6_2hr, sens6_4hr), axis=0) + sens7_full = np.concatenate((sens7_2hr, sens7_4hr), axis=0) + sens8_full = np.concatenate((sens8_2hr, sens8_4hr), axis=0) + pred = np.concatenate((sens1_full, sens2_full, sens3_full, sens4_full, + sens5_full, sens6_full, sens7_full, sens8_full), axis=0) + + + + pred_dhdt = (pred[curr_inds,:,:]-pred[prev_inds,:,:])/50 + swe_dhdt = (true[curr_inds,:,:]-true[prev_inds,:,:])/50 + true_div_prev = true_div[prev_inds,:,:] + true_div_curr = true_div[curr_inds,:,:] + true_div = -.5*(true_div_prev+true_div_curr) + times = true_div.shape[0] + full_mask = mask.repeat([times, 1]) + true_div_masked = torch.where((full_mask == 0), torch.from_numpy(true_div[..., 0]), 0) + pred_dhdt_masked = torch.where((full_mask == 0), torch.from_numpy(pred_dhdt[..., 0]), 0) + swe_dhdt_masked = torch.where((full_mask == 0), torch.from_numpy(swe_dhdt[..., 0]), 0) + maxes = [true_div_masked[i, :].abs().max() for i in range(len(true_div_masked))] + abs_errs = [(true_div_masked[i, :] - pred_dhdt_masked[i, :]).abs() for i in + range(len(true_div_masked))] + ratios = [(abs_errs[i] / maxes[i]) for i in range(len(abs_errs))] + emean = [ratios[i].flatten()[torch.where(ratios[i].flatten().abs() > 1e-8)].mean() for i in + range(len(abs_errs))] + swe_abs_errs = [(true_div_masked[i, :] - swe_dhdt_masked[i, :]).abs() for i in + range(len(true_div_masked))] + swe_emax = [(swe_abs_errs[i] / maxes[i]).max().item() for i in range(len(swe_abs_errs))] + swe_ratios = [(swe_abs_errs[i] / maxes[i]) for i in range(len(swe_abs_errs))] + swe_emean = [swe_ratios[i].flatten()[torch.where(torch.from_numpy(true_masked[i]).flatten().abs() > 1e-4)].mean() for i in + range(len(abs_errs))] + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14,6)) + ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) + ax.plot(emean) + ax.scatter(range(len(swe_emean)),swe_emean, c='r',s=.5) + txt = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' + ax.title.set_text(txt) + #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") + ax.set_xlabel(r"\textbf{Frame Index}") + ax.set_ylabel(r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h_t|)}}$", rotation='horizontal', + horizontalalignment='right', verticalalignment='center',fontsize=16) + ax.vlines(frame_ids, 0, max(np.max(swe_emean),np.max(emean)), color='g') + ax.legend([r'\textbf{Senseiver Err}', r'\textbf{SWE Err}', r'\textbf{epi split}']) + # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).mean():0.3}') + # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") + # ax.vlines(frame_ids, 0, max(np.max(swe_emean), np.max(emean)), color='g') + # ax.legend(['Senseiver Err', 'SWE Err', 'epi split']) + + plt.tight_layout() + if path: + plt.savefig(path + '/divergence_all_ts_{}'.format(data_type) + "_new.png") + plt.close() + + print(f'The mean mean div err is {torch.tensor(emean).mean():0.3}') + print(f'The mean max div err is {torch.tensor(swe_emax).mean():0.3}') + + plot_div_err_for_each_epi(emean, swe_emean, sim_times, data_type, path) + + + + + From 210fb07198d6ef1e6592d8d9516563d44f87f658 Mon Sep 17 00:00:00 2001 From: Edward McDugald Date: Wed, 16 Jul 2025 23:19:23 -0700 Subject: [PATCH 3/4] adding scripts for making grl figs --- tsunami/combined_plot.py | 310 ++++++++----- tsunami/make_grl_fig_2.py | 207 +++++++++ tsunami/make_grl_fig_3.py | 665 ++++++++++++++++++++++++++ tsunami/make_grl_fig_4.py | 822 +++++++++++++++++++++++++++++++++ tsunami/quick_plot_combined.py | 24 +- 5 files changed, 1897 insertions(+), 131 deletions(-) create mode 100644 tsunami/make_grl_fig_2.py create mode 100644 tsunami/make_grl_fig_3.py create mode 100644 tsunami/make_grl_fig_4.py diff --git a/tsunami/combined_plot.py b/tsunami/combined_plot.py index 8bd5e69..4c2e2a9 100644 --- a/tsunami/combined_plot.py +++ b/tsunami/combined_plot.py @@ -6,36 +6,48 @@ import os import scipy.io as sio import numpy as np -from matplotlib.ticker import FormatStrFormatter +from scipy.signal import medfilt +plt.rcParams['text.latex.preamble'] = r'\usepackage{bm}' + plt.rcParams['text.usetex'] = True plt.rcParams["font.family"] = "serif" -plt.rcParams["font.size"]=16 -plt.rcParams['xtick.labelsize']=16 -plt.rcParams['ytick.labelsize']=16 -plt.rcParams['text.latex.preamble'] = r'\usepackage{bm}' +plt.rcParams["font.size"] = 16 +plt.rcParams['xtick.labelsize'] = 16 +plt.rcParams['ytick.labelsize'] = 16 +plt.rcParams['font.weight'] = 'bold' +plt.rcParams['axes.titlesize'] = 16 +plt.rcParams['figure.titlesize'] = 16 +plt.rcParams['axes.labelsize'] = 16 +plt.rcParams['font.weight'] = 'bold' +plt.rcParams['axes.labelweight'] = 'bold' +plt.rcParams['axes.titleweight'] = 'bold' -def training_lons(): - return np.array([136.6180, 139.5560, 139.3290, 138.9350, 140.9290, 135.7400, 141.5010, 142.3870]) -def training_lats(): - return np.array([33.0700, 28.8560, 28.9320, 29.3840, 33.4530, 33.1570, 35.9360, 35.2670]) -def unseen_lons(): - return np.array([136.6500, 140.2000, 138.9000, 139.5000]) +def training_lons(): + return np.array([136.6180, 137.9440, 141.4890, 141.7020, + 142.2410, 142.4220, 143.0040, + 143.3780, 143.4350, 143.8990, + 136.8940]) -def unseen_lats(): - return np.array([33.1000, 29.1000, 28.1000, 28.8000]) +def training_lats(): + return np.array([33.0700, 33.1670, 34.5520, 36.1940, + 37.1470, 40.5910, 38.6600, + 40.1760, 40.8600, 42.0840, + 33.6830]) def unseen_lons_new(): - return np.array([136.6500,138.2000,138.9000, - 139.5000,140.2000,140.5000, - 141.5000,142.5000]) + return np.array([140.8270, 141.4590, 142.0500, 142.5420, + 143.2280, 144.0600, 142.6190, + 143.4160,135.9050,137.0710, + 138.0250]) def unseen_lats_new(): - return np.array([33.1000,31.0000,28.1000, - 28.8000,29.1000,31.8000, - 34.2000,36.2000]) + return np.array([33.3620, 36.5340, 41.0190, 34.7450, + 39.8690, 40.2320, 37.8120, + 41.4150, 33.1230, 33.1840, + 34.1750]) @@ -52,7 +64,7 @@ def plot_err_for_each_epi(means, times, dta_type, path, plot_pts): j=0 for i in range(1, len(time_idxs)): - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7.2,6)) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 4.5)) emean = means[int(time_idxs[i - 1]):int(time_idxs[i])] epi_lon = lons[j] epi_lat = lats[j] @@ -62,12 +74,9 @@ def plot_err_for_each_epi(means, times, dta_type, path, plot_pts): plot_points = pts[(pts > start_plt_pts) & (pts < end_plt_pts)] - start_plt_pts ax.plot(emean) ax.scatter(plot_points*(5./6.), torch.tensor(emean)[plot_points], c='r', marker='o', s=5) - ax.title.set_text(f'Average Error = {torch.tensor(emean).mean():0.3}') - #ax.set(xlabel=r"\textbf{Time Steps}", ylabel="$\mathbf{|true-pred|/\max(|true|)}$") - - ax.set_xlabel(r"\textbf{Time Steps}") - ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', - horizontalalignment='right', verticalalignment='center',fontsize=16) + txt = r'\textbf{Average Error = }' + r'\textbf{' + str( + round(torch.tensor(emean).mean().item(), 3)) + r'}' + ax.set(xlabel=r"\textbf{Time Steps}", ylabel=txt) ax.legend(['all data', 'training data', '.10 threshold', '.05 threshold']) @@ -85,37 +94,33 @@ def plot_err_for_each_epi(means, times, dta_type, path, plot_pts): else: lt_idx = None # If no such index exists - text = r'\textbf{Average Error = }' + r'\textbf{'+str(round(torch.tensor(emean).mean().item(),3)) + r'}' - ax.title.set_text(text) - # ax.title.set_text(f'Average Error = {torch.tensor(emean).mean():0.3}') - ax.scatter(plot_pts[90],emean[90],s=75,marker="x",color='r') - ax.scatter(plot_pts[180], emean[180], s=75, marker="x", color='r') - ax.scatter(plot_pts[270], emean[270], s=75, marker="x", color='r') - ax.scatter(plot_pts[lt_idx], emean[lt_idx], s=75, marker="x", color='g') + text = r'\textbf{Average Error = }' + r'\textbf{' + str( + round(torch.tensor(emean).mean().item(), 3)) + r'}' + ax.scatter(plot_pts[90],emean[90],s=300,marker="x",color='r') + ax.scatter(plot_pts[180], emean[180], s=300, marker="x", color='r') + ax.scatter(plot_pts[270], emean[270], s=300, marker="x", color='r') + ax.scatter(plot_pts[lt_idx], emean[lt_idx], s=300, marker="x", color='g') ax.set_xlim(0,240) ax.set_ylim(0,np.max(emean)) - # ax.set(xlabel="Time (minutes)", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") - # ax.legend(['_', - # r'Error at $75$ mins: ${}$'.format(round(emean[90],3)), - # r'Error at $150$ mins: ${}$'.format(round(emean[180],3)), - # r'Error at 225 mins: ${}$'.format(round(emean[270],3)), - # r'$10\%$ threshold time: ${}$ mins'.format(round(plot_pts[lt_idx],1))]) - #ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") ax.set_xlabel(r"\textbf{Time (minutes)}") - ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', - horizontalalignment='right', verticalalignment='center',fontsize=16) - - ax.legend(['_', - r'\textbf{Error at $75$ mins: ' + r'\textbf{' + str(round(emean[90],3)) + r'}', - r'\textbf{Error at $150$ mins: ' + r'\textbf{' + str(round(emean[180], 3)) + r'}', - r'\textbf{Error at $225$ mins: ' + r'\textbf{' + str(round(emean[270], 3)) + r'}', - r'\textbf{$10\%$ threshold time: }' + r'\textbf{' +str(round(plot_pts[lt_idx], 1)) + r'}' + r'\textbf{ mins}']) + ax.set_ylabel(text) + labels = [ + '_', + r'\textbf{Error at }$\mathbf{75}$\textbf{ mins: }$\mathbf{' + str(round(emean[90], 3)) + '}$', + r'\textbf{Error at }$\mathbf{150}$\textbf{ mins: }$\mathbf{' + str(round(emean[180], 3)) + '}$', + r'\textbf{Error at }$\mathbf{225}$\textbf{ mins: }$\mathbf{' + str(round(emean[270], 3)) + '}$', + r'\textbf{10\% threshold time: }$\mathbf{' + str(round(plot_pts[lt_idx], 1)) + r'}$\textbf{ mins}' + ] + ax.legend(labels) + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight('bold') j += 1 plt.tight_layout() if path: plt.savefig(path + '/all_ts_{}'.format(dta_type)+"_epi_{}_{}".format(epi_lon,epi_lat)+"_new.png",dpi=400,bbox_inches='tight') + plt.close() @@ -136,10 +141,10 @@ def plot_timeseries_for_each_epi_unstruct(true,pred,times,sens_idxs,sens_locs,dt for idx in sens_idxs[0]: true_ts = true[int(time_idxs[i - 1]):int(time_idxs[i]), idx] pred_ts = pred[int(time_idxs[i - 1]):int(time_idxs[i]), idx] - if np.abs(true_ts).max() >= 5e-2: + if np.abs(true_ts).max() >= 3e-2: times = np.arange(0,len(true_ts),1)*(5/6) times = times - fig, ax = plt.subplots(figsize=(7.2,6)) + fig, ax = plt.subplots(figsize=(7, 4.5)) ax.plot(times,true_ts, c='g') ax.plot(times,pred_ts, c='k') @@ -147,22 +152,82 @@ def plot_timeseries_for_each_epi_unstruct(true,pred,times,sens_idxs,sens_locs,dt lat_deg = sens_locs[j][1] * 180 / np.pi ax.set_xlim(0, 240) - # ax.title.set_text('Wave Height Time Series at sensor ({}$^\circ$E, {}$^\circ$N)'.format(round(lon_deg,1), round(lat_deg,1))) - # ax.set(xlabel="Time (minutes)", ylabel="Wave Height (m)") - # ax.legend(['True Height', 'Predicted Height']) - txt = r'\textbf{Wave Height at Sensor (}' + r'\textbf{' + str( + txt = r"\textbf{Height (m) at}" + r'\textbf{ (}' + r'\textbf{' + str( round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' - # txt = r'\textbf{Wave Height Time Series at Sensor (}' + r'\textbf{' + str(round(lon_deg,1))+r'}' + r'\textbf{$^\circ$E}, ' - # + r'\textbf{'+str(round(lat_deg,1))+r'}'+r'\textbf{$^\circ$N)}' - ax.title.set_text(txt) - ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"\textbf{Wave Height (m)}") + ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=txt) ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}']) + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight('bold') + + plt.tight_layout() + plt.savefig(path + 'timeseries/{}'.format(dta_type) + + "/sens_{}_{}_epi_{}_{}".format(round(lon_deg, 3), round(lat_deg, 3), + round(lons[i - 1], 3), round(lats[i - 1], 3)) + "_new.png", + dpi=400, bbox_inches='tight') + plt.close() + + fig, ax = plt.subplots(figsize=(7, 4.5)) + pred_ts_filtered = medfilt(pred_ts, 9) + pred_ts_filtered = medfilt(pred_ts_filtered, 3) + ax.plot(times, true_ts, c='g') + ax.plot(times, pred_ts_filtered, c='k') + + lon_deg = sens_locs[j][0] * 180 / np.pi + lat_deg = sens_locs[j][1] * 180 / np.pi + ax.set_xlim(0, 240) + + txt = r"\textbf{Height (m) at}" + r'\textbf{ (}' + r'\textbf{' + str( + round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( + round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' + ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=txt) + ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}']) + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight('bold') plt.tight_layout() plt.savefig(path + 'timeseries/{}'.format(dta_type) - + "/sens_{}_{}_epi_{}_{}".format(round(lon_deg,3),round(lat_deg,3),round(lons[i-1],3),round(lats[i-1],3)) + "_new.png",dpi=400,bbox_inches='tight') + + "/sens_{}_{}_epi_{}_{}".format(round(lon_deg, 3), round(lat_deg, 3), + round(lons[i - 1], 3), + round(lats[i - 1], 3)) + "_filtered_v2.png", + dpi=400, bbox_inches='tight') + plt.close() + + ###################################### + + fig, ax = plt.subplots(figsize=(7, 4.5)) + pred_ts_filtered = medfilt(pred_ts,21) + pred_ts_filtered = medfilt(pred_ts_filtered, 19) + pred_ts_filtered = medfilt(pred_ts_filtered, 17) + pred_ts_filtered = medfilt(pred_ts_filtered, 15) + pred_ts_filtered = medfilt(pred_ts_filtered, 13) + pred_ts_filtered = medfilt(pred_ts_filtered, 11) + pred_ts_filtered = medfilt(pred_ts_filtered, 9) + pred_ts_filtered = medfilt(pred_ts_filtered, 7) + pred_ts_filtered = medfilt(pred_ts_filtered, 5) + pred_ts_filtered = medfilt(pred_ts_filtered, 3) + ax.plot(times, true_ts, c='g') + ax.plot(times, pred_ts_filtered, c='k') + + lon_deg = sens_locs[j][0] * 180 / np.pi + lat_deg = sens_locs[j][1] * 180 / np.pi + ax.set_xlim(0, 240) + + txt = r"\textbf{Height (m) at}" + r'\textbf{ (}' + r'\textbf{' + str( + round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( + round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' + ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=txt) + ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}']) + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight('bold') + + plt.tight_layout() + plt.savefig(path + 'timeseries/{}'.format(dta_type) + + "/sens_{}_{}_epi_{}_{}".format(round(lon_deg, 3), round(lat_deg, 3), + round(lons[i - 1], 3), + round(lats[i - 1], 3)) + "_filtered.png", + dpi=400, bbox_inches='tight') plt.close() j += 1 @@ -184,7 +249,7 @@ def plot_div_err_for_each_epi(means, swemeans, times, dta_type, path): j=0 for i in range(1, len(time_idxs)): - fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(7.2,6)) + fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8.5,4.5)) emean = means[int(time_idxs[i - 1])+8:int(time_idxs[i])] swe_emean = swemeans[int(time_idxs[i - 1])+8:int(time_idxs[i])] times = np.arange(0,len(emean),1)*(5/6) @@ -193,16 +258,9 @@ def plot_div_err_for_each_epi(means, swemeans, times, dta_type, path): epi_lat = lats[j] ax.plot(times, emean, c='g') ax.plot(times, swe_emean, c='k') - # ax.title.set_text(f'Average $h_t$ error = {torch.tensor(emean).mean():0.3}') - # ax.set(xlabel="Time (minutes)", ylabel=r"$\frac{|h_t-\hat{h_t}|}{\max(|h|)}$") - # ax.legend(['Senseiver Error', 'SWE Error']) - txt = r'\textbf{Average }' +r'$\mathbf{h_t}$' + r'\textbf{ Error = }' + r'\textbf{'+str(round(torch.tensor(emean).mean().item(),3)) + r'}' - ax.title.set_text(txt) - #ax.set(xlabel=r"\textbf{Time (minutes)}", ylabel=r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h|)}}$") + txt = r'\textbf{Average Error = }' + r'\textbf{'+str(round(torch.tensor(emean).mean().item(),3)) + r'}' ax.set_xlabel(r"\textbf{Time (minutes)}") - ax.set_ylabel(r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h|)}}$", - rotation='horizontal', horizontalalignment='right', - verticalalignment='center',fontsize=16) + ax.set_ylabel(txt) ax.legend([r'\textbf{Senseiver Error}', r'\textbf{SWE Error}']) ax.set_xlim(0,240) @@ -210,6 +268,7 @@ def plot_div_err_for_each_epi(means, swemeans, times, dta_type, path): plt.tight_layout() if path: plt.savefig(path + '/all_ts_div_err_{}'.format(dta_type)+"_epi_{}_{}".format(epi_lon,epi_lat)+"_new.png",dpi=400,bbox_inches='tight') + plt.close() def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname, data_type,path=None): @@ -229,7 +288,7 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname true_masked_4hr = torch.where((full_mask_4hr == 0), true4hr[..., 0], 0) pred_masked_4hr = torch.where((full_mask_4hr == 0), pred4hr[..., 0], 0) intervals_2hr = [(0, 145), (145, 290), (290, 435), (435, 580), (580, 725), (725, 870), (870, 1015), - (1015, 1160)] + (1015, 1160), (1160, 1305), (1305, 1450), (1450, 1595)] sim1_2hr = true_masked_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] sim2_2hr = true_masked_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] sim3_2hr = true_masked_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] @@ -238,8 +297,11 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_2hr = true_masked_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] sim7_2hr = true_masked_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] sim8_2hr = true_masked_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sim9_2hr = true_masked_2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] + sim10_2hr = true_masked_2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] + sim11_2hr = true_masked_2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] intervals_4hr = [(0, 144), (144, 288), (288, 432), (432, 576), (576, 720), (720, 864), (864, 1008), - (1008, 1152)] + (1008, 1152), (1152, 1296), (1296, 1440), (1440, 1584)] sim1_4hr = true_masked_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] sim2_4hr = true_masked_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] sim3_4hr = true_masked_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] @@ -248,6 +310,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_4hr = true_masked_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] sim7_4hr = true_masked_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] sim8_4hr = true_masked_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim9_4hr = true_masked_4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] + sim10_4hr = true_masked_4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] + sim11_4hr = true_masked_4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] sim1_full = np.concatenate((sim1_2hr, sim1_4hr), axis=0) sim2_full = np.concatenate((sim2_2hr, sim2_4hr), axis=0) sim3_full = np.concatenate((sim3_2hr, sim3_4hr), axis=0) @@ -256,8 +321,12 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_full = np.concatenate((sim6_2hr, sim6_4hr), axis=0) sim7_full = np.concatenate((sim7_2hr, sim7_4hr), axis=0) sim8_full = np.concatenate((sim8_2hr, sim8_4hr), axis=0) + sim9_full = np.concatenate((sim9_2hr, sim9_4hr), axis=0) + sim10_full = np.concatenate((sim10_2hr, sim10_4hr), axis=0) + sim11_full = np.concatenate((sim11_2hr, sim11_4hr), axis=0) true_masked = np.concatenate((sim1_full,sim2_full,sim3_full,sim4_full, - sim5_full,sim6_full,sim7_full,sim8_full),axis=0) + sim5_full,sim6_full,sim7_full,sim8_full, + sim9_full, sim10_full, sim11_full),axis=0) sens1_2hr = pred_masked_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] sens2_2hr = pred_masked_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] @@ -267,6 +336,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_2hr = pred_masked_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] sens7_2hr = pred_masked_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] sens8_2hr = pred_masked_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sens9_2hr = pred_masked_2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] + sens10_2hr = pred_masked_2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] + sens11_2hr = pred_masked_2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] sens1_4hr = pred_masked_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] sens2_4hr = pred_masked_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] sens3_4hr = pred_masked_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] @@ -275,6 +347,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_4hr = pred_masked_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] sens7_4hr = pred_masked_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] sens8_4hr = pred_masked_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sens9_4hr = pred_masked_4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] + sens10_4hr = pred_masked_4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] + sens11_4hr = pred_masked_4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] sens1_full = np.concatenate((sens1_2hr, sens1_4hr), axis=0) sens2_full = np.concatenate((sens2_2hr, sens2_4hr), axis=0) sens3_full = np.concatenate((sens3_2hr, sens3_4hr), axis=0) @@ -283,8 +358,12 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_full = np.concatenate((sens6_2hr, sens6_4hr), axis=0) sens7_full = np.concatenate((sens7_2hr, sens7_4hr), axis=0) sens8_full = np.concatenate((sens8_2hr, sens8_4hr), axis=0) + sens9_full = np.concatenate((sens9_2hr, sens9_4hr), axis=0) + sens10_full = np.concatenate((sens10_2hr, sens10_4hr), axis=0) + sens11_full = np.concatenate((sens11_2hr, sens11_4hr), axis=0) pred_masked = np.concatenate((sens1_full,sens2_full,sens3_full,sens4_full, - sens5_full,sens6_full,sens7_full,sens8_full),axis=0) + sens5_full,sens6_full,sens7_full,sens8_full, + sens9_full, sens10_full, sens11_full),axis=0) maxes = [np.abs(true_masked[i, :]).max() for i in range(len(true_masked))] @@ -294,7 +373,7 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname ratios = [(abs_errs[i] / maxes[i]) for i in range(len(abs_errs))] emean = [ratios[i].flatten()[np.where(np.abs(true_masked[i].flatten()) > 1e-4)].mean().item() for i in range(len(abs_errs))] - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14,6)) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7,3)) plot_points_2hr = dataloader1.dataset.train_ind plot_points_4hr = dataloader2.dataset.train_ind plot_points_4hr = plot_points_4hr + 1160 @@ -311,40 +390,30 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname text = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' ax.title.set_text(text) - #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") ax.set_xlabel(r"\textbf{Frame Index}") ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', - horizontalalignment='right', verticalalignment='center',fontsoze=16) + horizontalalignment='right', verticalalignment='center') ax.vlines(frame_ids, 0, np.max(emean), color='g') ax.legend([r'\textbf{all data}', r'\textbf{training data}',r'\textbf{epi split}']) - # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).nanmean():0.3}') - # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") - # ax.vlines(frame_ids, 0, np.max(emean), color='g') - # ax.legend(['all data', 'training data', 'epi split']) else: ax.plot(emean) text = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' ax.title.set_text(text) - #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") - ax.set_xlabel(r"\textbf{Frame Index}") - ax.set_ylabel(r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$", rotation='horizontal', - horizontalalignment='right', verticalalignment='center',fontsize=16) + ax.set_ylabel(text, rotation='horizontal', + horizontalalignment='right', verticalalignment='center', fontsize=8) ax.vlines(frame_ids, 0, np.max(emean), color='g') ax.legend([r'\textbf{all data}', r'\textbf{epi split}']) - # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).nanmean():0.3}') - # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") - # ax.vlines(frame_ids, 0, np.max(emean), color='g') - # ax.legend(['all data', 'epi split']) plt.tight_layout() if path: - plt.savefig(path + '/all_ts_{}'.format(data_type)+"_new.png") + #plt.savefig(path + '/all_ts_{}'.format(data_type)+"_new.pdf") + plt.savefig(path + '/all_ts_{}'.format(data_type) + "_new.png",dpi=400) plt.close() print(f'The mean mean err is {torch.tensor(emean).nanmean():0.3}') @@ -368,7 +437,7 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname pred_ts = pred_masked[:,idx] if np.abs(true_ts).max() >= 1e-1: - fig, ax = plt.subplots(figsize=(14,6)) + fig, ax = plt.subplots(figsize=(7,3)) ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) ax.plot(true_ts,c='g') @@ -377,20 +446,16 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname lon_deg = sensor_locs[j][1] * 180 / np.pi - 180 lat_deg = sensor_locs[j][0] * 180 / np.pi - txt = r'\textbf{Wave Height at Sensor (}' + r'\textbf{' + str( + txt = r'\textbf{Sensor (}' + r'\textbf{' + str( round(lon_deg, 1)) + r'}' + r'\textbf{$^\circ$E, }' + r'\textbf{' + str( round(lat_deg, 1)) + r'}' + r'\textbf{$^\circ$N)}' ax.title.set_text(txt) ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"\textbf{Normalized Wave Height (m)}") ax.legend([r'\textbf{True Height}', r'\textbf{Predicted Height}', r'\textbf{Absolute Error}']) - - # ax.title.set_text('Wave Height Time Series at {}, {}'.format(round(lon_deg), round(lat_deg))) - # ax.set(xlabel="Frame Index", ylabel="Normalized Wave Height (m)") - # ax.legend(['True Height', 'Predicted Height', 'Absolute Error']) - plt.tight_layout() - plt.savefig(path + 'timeseries/{}'.format(data_type)+"/sens_{}_{}_all_epis".format(round(lon_deg),round(lat_deg))+"_new.png") + plt.savefig(path + 'timeseries/{}'.format(data_type)+"/sens_{}_{}_all_epis".format(round(lon_deg),round(lat_deg))+"_new.png",dpi=400) + plt.close() j += 1 @@ -400,10 +465,10 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname ### PHYSICAL CONSISTENCY PLOTS ### n_sims = len(dataloader1.dataset.time_idx) train_ind_start = np.append(0, dataloader1.dataset.time_idx[0:n_sims - 1]) - for i in range(8): + for i in range(11): train_ind_start[i] += 144*i combined_times = dataloader1.dataset.time_idx + dataloader2.dataset.time_idx - train_inds = torch.concatenate([torch.range(train_ind_start[i], combined_times[i] - 2) for i in range(8)]) + train_inds = torch.concatenate([torch.range(train_ind_start[i], combined_times[i] - 2) for i in range(11)]) curr_inds = torch.add(train_inds, 1).long() prev_inds = train_inds.long() @@ -417,6 +482,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6div_2hr = truediv_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] sim7div_2hr = truediv_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] sim8div_2hr = truediv_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sim9div_2hr = truediv_2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] + sim10div_2hr = truediv_2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] + sim11div_2hr = truediv_2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] sim1div_4hr = truediv_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] sim2div_4hr = truediv_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] sim3div_4hr = truediv_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] @@ -425,6 +493,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6div_4hr = truediv_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] sim7div_4hr = truediv_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] sim8div_4hr = truediv_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim9div_4hr = truediv_4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] + sim10div_4hr = truediv_4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] + sim11div_4hr = truediv_4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] sim1div_full = np.concatenate((sim1div_2hr, sim1div_4hr), axis=0) sim2div_full = np.concatenate((sim2div_2hr, sim2div_4hr), axis=0) sim3div_full = np.concatenate((sim3div_2hr, sim3div_4hr), axis=0) @@ -433,8 +504,12 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6div_full = np.concatenate((sim6div_2hr, sim6div_4hr), axis=0) sim7div_full = np.concatenate((sim7div_2hr, sim7div_4hr), axis=0) sim8div_full = np.concatenate((sim8div_2hr, sim8div_4hr), axis=0) + sim9div_full = np.concatenate((sim9div_2hr, sim9div_4hr), axis=0) + sim10div_full = np.concatenate((sim10div_2hr, sim10div_4hr), axis=0) + sim11div_full = np.concatenate((sim11div_2hr, sim11div_4hr), axis=0) true_div = np.concatenate((sim1div_full, sim2div_full, sim3div_full, sim4div_full, - sim5div_full, sim6div_full, sim7div_full, sim8div_full), axis=0) + sim5div_full, sim6div_full, sim7div_full, sim8div_full, + sim9div_full, sim10div_full, sim11div_full), axis=0) true2hr = dataloader1.dataset.data.to('cpu') true4hr = dataloader2.dataset.data.to('cpu') @@ -448,6 +523,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_2hr = true2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] sim7_2hr = true2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] sim8_2hr = true2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sim9_2hr = true2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] + sim10_2hr = true2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] + sim11_2hr = true2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] sim1_4hr = true4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] sim2_4hr = true4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] sim3_4hr = true4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] @@ -456,6 +534,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_4hr = true4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] sim7_4hr = true4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] sim8_4hr = true4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sim9_4hr = true4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] + sim10_4hr = true4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] + sim11_4hr = true4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] sim1_full = np.concatenate((sim1_2hr, sim1_4hr), axis=0) sim2_full = np.concatenate((sim2_2hr, sim2_4hr), axis=0) sim3_full = np.concatenate((sim3_2hr, sim3_4hr), axis=0) @@ -464,8 +545,12 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sim6_full = np.concatenate((sim6_2hr, sim6_4hr), axis=0) sim7_full = np.concatenate((sim7_2hr, sim7_4hr), axis=0) sim8_full = np.concatenate((sim8_2hr, sim8_4hr), axis=0) + sim9_full = np.concatenate((sim9_2hr, sim9_4hr), axis=0) + sim10_full = np.concatenate((sim10_2hr, sim10_4hr), axis=0) + sim11_full = np.concatenate((sim11_2hr, sim11_4hr), axis=0) true = np.concatenate((sim1_full, sim2_full, sim3_full, sim4_full, - sim5_full, sim6_full, sim7_full, sim8_full), axis=0) + sim5_full, sim6_full, sim7_full, sim8_full, + sim9_full, sim10_full, sim11_full), axis=0) sens1_2hr = pred2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] sens2_2hr = pred2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] sens3_2hr = pred2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] @@ -474,6 +559,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_2hr = pred2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] sens7_2hr = pred2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] sens8_2hr = pred2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] + sens9_2hr = pred2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] + sens10_2hr = pred2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] + sens11_2hr = pred2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] sens1_4hr = pred4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] sens2_4hr = pred4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] sens3_4hr = pred4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] @@ -482,6 +570,9 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_4hr = pred4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] sens7_4hr = pred4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] sens8_4hr = pred4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] + sens9_4hr = pred4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] + sens10_4hr = pred4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] + sens11_4hr = pred4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] sens1_full = np.concatenate((sens1_2hr, sens1_4hr), axis=0) sens2_full = np.concatenate((sens2_2hr, sens2_4hr), axis=0) sens3_full = np.concatenate((sens3_2hr, sens3_4hr), axis=0) @@ -490,8 +581,12 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname sens6_full = np.concatenate((sens6_2hr, sens6_4hr), axis=0) sens7_full = np.concatenate((sens7_2hr, sens7_4hr), axis=0) sens8_full = np.concatenate((sens8_2hr, sens8_4hr), axis=0) + sens9_full = np.concatenate((sens9_2hr, sens9_4hr), axis=0) + sens10_full = np.concatenate((sens10_2hr, sens10_4hr), axis=0) + sens11_full = np.concatenate((sens11_2hr, sens11_4hr), axis=0) pred = np.concatenate((sens1_full, sens2_full, sens3_full, sens4_full, - sens5_full, sens6_full, sens7_full, sens8_full), axis=0) + sens5_full, sens6_full, sens7_full, sens8_full, + sens9_full, sens10_full, sens11_full), axis=0) @@ -517,26 +612,21 @@ def plot_all_ts_tsu_japan_unstruct(dataloader1,dataloader2,pred1, pred2, matname swe_ratios = [(swe_abs_errs[i] / maxes[i]) for i in range(len(swe_abs_errs))] swe_emean = [swe_ratios[i].flatten()[torch.where(torch.from_numpy(true_masked[i]).flatten().abs() > 1e-4)].mean() for i in range(len(abs_errs))] - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14,6)) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7,3)) ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) ax.plot(emean) ax.scatter(range(len(swe_emean)),swe_emean, c='r',s=.5) txt = r'\textbf{Average Error = }' + r'\textbf{' + str(round(torch.tensor(emean).mean().item(), 3)) + r'}' ax.title.set_text(txt) - #ax.set(xlabel=r"\textbf{Frame Index}", ylabel=r"$\mathbf{\frac{|h-\hat{h}|}{\max(|h|)}}$") ax.set_xlabel(r"\textbf{Frame Index}") ax.set_ylabel(r"$\mathbf{\frac{|h_t-\hat{h_t}|}{\max(|h_t|)}}$", rotation='horizontal', - horizontalalignment='right', verticalalignment='center',fontsize=16) + horizontalalignment='right', verticalalignment='center',fontsize=8) ax.vlines(frame_ids, 0, max(np.max(swe_emean),np.max(emean)), color='g') ax.legend([r'\textbf{Senseiver Err}', r'\textbf{SWE Err}', r'\textbf{epi split}']) - # ax.title.set_text(f'$\epsilon$ = {torch.tensor(emean).mean():0.3}') - # ax.set(xlabel="Frame Index", ylabel=r"$\frac{|h-\hat{h}|}{\max(|h|)}$") - # ax.vlines(frame_ids, 0, max(np.max(swe_emean), np.max(emean)), color='g') - # ax.legend(['Senseiver Err', 'SWE Err', 'epi split']) plt.tight_layout() if path: - plt.savefig(path + '/divergence_all_ts_{}'.format(data_type) + "_new.png") + plt.savefig(path + '/divergence_all_ts_{}'.format(data_type) + "_new.png",dpi=400) plt.close() print(f'The mean mean div err is {torch.tensor(emean).mean():0.3}') diff --git a/tsunami/make_grl_fig_2.py b/tsunami/make_grl_fig_2.py new file mode 100644 index 0000000..c1ccf0b --- /dev/null +++ b/tsunami/make_grl_fig_2.py @@ -0,0 +1,207 @@ +import rasterio +import matplotlib.pyplot as plt +import numpy as np +from rasterio.merge import merge +from datasets import SWETsunamiForPlotting3 +import geopandas as gpd +import scipy.io as sio +import os + +## data source ## +### https://topotools.cr.usgs.gov/gmted_viewer/viewer.htm ### + +### UNCOMMENT BELOW CODE TO MAKE merged.tif file ### +# f1 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N090E_20101117_gmted_mea075.tif') +# dta1 = f1.read() +# f2 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N120E_20101117_gmted_mea075.tif') +# dta2 = f2.read() +# f3 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N150E_20101117_gmted_mea075.tif') +# dta3 = f3.read() +# f4 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S090E_20101117_gmted_mea075.tif') +# dta4 = f4.read() +# f5 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S120E_20101117_gmted_mea075.tif') +# dta5 = f5.read() +# f6 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S150E_20101117_gmted_mea075.tif') +# dta6 = f6.read() +# f7 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N090E_20101117_gmted_mea075.tif') +# dta7 = f7.read() +# f8 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N120E_20101117_gmted_mea075.tif') +# dta8 = f8.read() +# f9 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N150E_20101117_gmted_mea075.tif') +# dta9 = f9.read() +# f10 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N090E_20101117_gmted_mea075.tif') +# dta10 = f10.read() +# f11 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N120E_20101117_gmted_mea075.tif') +# dta11 = f11.read() +# f12 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N150E_20101117_gmted_mea075.tif') +# dta12 = f12.read() +# +# combined_data = merge([f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12]) +# full_data = combined_data[0][0] +# transform = combined_data[1] +# topo_dataset = rasterio.open(os.getcwd()+'/geodata/merged.tif', 'w',driver='GTiff', +# height=full_data.shape[0],width=full_data.shape[1],count=1, +# dtype=full_data.dtype,crs='+proj=latlong',transform=transform) +# topo_dataset.write(full_data, 1) +# topo_dataset.close() + +topo = rasterio.open(os.getcwd()+'/geodata/merged.tif') +topo_data = topo.read() +df = gpd.read_file(os.getcwd()+'/geodata/ne_10m_admin_0_countries.shp') + +chi = df.loc[df['ADMIN'] == 'China'] +jap = df.loc[df['ADMIN'] == 'Japan'] +rus = df.loc[df['ADMIN'] == 'Russia'] +sko = df.loc[df['ADMIN'] == 'South Korea'] +nko = df.loc[df['ADMIN'] == 'North Korea'] +phi = df.loc[df['ADMIN'] == 'Philippines'] +tai = df.loc[df['ADMIN'] == 'Taiwan'] +vie = df.loc[df['ADMIN'] == 'Vietnam'] + + + +def training_lons(): + return np.array([136.6180, 137.9440, 141.4890, 141.7020, + 142.2410, 142.4220, 143.0040, + 143.3780, 143.4350, 143.8990, + 136.8940]) + +def training_lats(): + return np.array([33.0700, 33.1670, 34.5520, 36.1940, + 37.1470, 40.5910, 38.6600, + 40.1760, 40.8600, 42.0840, + 33.6830]) + + +def unseen_lons_new(): + return np.array([140.8270, 141.4590, 142.0500, 142.5420, + 143.2280, 144.0600, 142.6190, 143.4160,135.9050,137.0710,138.0250]) + +def unseen_lats_new(): + return np.array([33.3620, 36.5340, 41.0190, 34.7450, + 39.8690, 40.2320,37.8120,41.4150,33.1230,33.1840,34.1750]) + +min_lat = -10 +max_lat = 70 +min_lon = 90 +max_lon = 180 + +min_lat_plot = 32.5 +max_lat_plot = 45 +min_lon_plot = 132.5 +max_lon_plot = 145 +nlat , nlon = np.shape(topo_data[0]) +lat_vals = np.linspace(max_lat,min_lat,nlat) +lon_vals = np.linspace(min_lon,max_lon,nlon) +min_lat_idx = np.where(lat_vals >= min_lat_plot)[0][-1] +max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,:]) +lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +dx = (lon_vals_new[1]-lon_vals_new[0])/2. +dy = (lat_vals_new[1]-lat_vals_new[0])/2. +extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] + + +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 24 +plt.rcParams['font.weight'] = 'bold' +plt.rcParams['axes.labelweight'] = 'bold' +plt.rcParams['axes.titleweight'] = 'bold' +### PLOT THE EPICENTERS ### +fig , ax = plt.subplots(figsize=(9,9)) +top = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] + +from numpy.ma import masked_array +top_img = masked_array(top, top == 0.0) +ocn_img = masked_array(top, top != 0.0) +ax.set_facecolor('deepskyblue') +im1 = ax.imshow(top_img,cmap='summer', extent=extent) +im1 = ax.scatter(training_lons(),training_lats(),color='k',s=125, marker='x',linewidths=4) +im2 = ax.scatter(unseen_lons_new(),unseen_lats_new(),color='r',s=125,marker='x',linewidths=4) +ax.set_xlim([extent[0], extent[1]]) +ax.set_ylim([extent[2],extent[3]]) +#fig.text(0.05, 0.90, '(b)', fontsize=24, fontweight='bold', va='top') +plt.legend((im1,im2),('Train epicenters','Test epicenters'),loc='upper left') +plt.grid(color = 'gray', linestyle = '--', linewidth = 0.5) +plt.savefig(os.getcwd()+"/fig2_test_train_epis.png",bbox_inches='tight',dpi=600) +plt.close() + + + + + +######################################### + + +#### PLOT THE SENSORS #### +min_lat_plot = 0 +max_lat_plot = 55 +min_lon_plot = 115 +max_lon_plot = 170 +fname_2hr = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_145.mat" +true_data_2hr, latitude_2hr, longitude_2hr, mask_2hr, max_ht_2hr, times_2hr, sensors_2hr, div_2hr = SWETsunamiForPlotting3(fname_2hr) +tsun_lons = longitude_2hr*(180 / np.pi) +tsun_lats = latitude_2hr*(180 / np.pi) +mat_data = sio.loadmat(os.getcwd()+"/Data/tsunami/"+fname_2hr) +sensor_indices = mat_data['sensor_loc_indices'] +sens_lons = tsun_lons[sensor_indices[0]] +sens_lats = tsun_lats[sensor_indices[0]] + +sens_lons_inner = sens_lons[(min_lon_plot<=sens_lons) & (sens_lons <= max_lon_plot) & (min_lat_plot<=sens_lats) & (sens_lats <= max_lat_plot)] +sens_lats_inner = sens_lats[(min_lon_plot<=sens_lons) & (sens_lons <= max_lon_plot) & (min_lat_plot<=sens_lats) & (sens_lats <= max_lat_plot)] + + + +nlat , nlon = np.shape(topo_data[0]) +lat_vals = np.linspace(max_lat,min_lat,nlat) +lon_vals = np.linspace(min_lon,max_lon,nlon) +min_lat_idx = np.where(lat_vals >= min_lat_plot)[0][-1] +max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,:]) +lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +dx = (lon_vals_new[1]-lon_vals_new[0])/2. +dy = (lat_vals_new[1]-lat_vals_new[0])/2. +extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] +top = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] + + + +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 24 +plt.rcParams['font.weight'] = 'bold' +plt.rcParams['axes.labelweight'] = 'bold' +plt.rcParams['axes.titleweight'] = 'bold' + +top_img = masked_array(top, top == 0.0) +ocn_img = masked_array(top, top != 0.0) + +fig , ax = plt.subplots(figsize=(9,9)) +ax.set_facecolor('deepskyblue') +im3 = ax.imshow(top_img,cmap='summer', extent=extent) +im3 = ax.scatter(sens_lons_inner,sens_lats_inner,color='k',marker="^",s=150) + + +min_lat_plot = 32.5 +max_lat_plot = 45 +min_lon_plot = 132.5 +max_lon_plot = 145 +import matplotlib.patches as patches +rect = patches.Rectangle((min_lon_plot, min_lat_plot), max_lon_plot-min_lon_plot, max_lat_plot-min_lat_plot, + linewidth=4, edgecolor='red', linestyle='--', facecolor='none') +ax.add_patch(rect) +ax.set_xlim([extent[0], extent[1]]) +ax.set_ylim([extent[2],extent[3]]) +ax.set_ylabel(r'Latitude ($^\circ$N)', labelpad=2.5) +ax.set_xlabel(r'Longitude ($^\circ$E)', labelpad=2.5) +ax.legend(['Dart Buoys']) +plt.grid(color = 'gray', linestyle = '--', linewidth = 0.5) +plt.savefig(os.getcwd()+"/fig2_buoys.png",bbox_inches='tight',dpi=600) +plt.close() diff --git a/tsunami/make_grl_fig_3.py b/tsunami/make_grl_fig_3.py new file mode 100644 index 0000000..a8c005d --- /dev/null +++ b/tsunami/make_grl_fig_3.py @@ -0,0 +1,665 @@ +import multiprocessing +multiprocessing.set_start_method("fork") +import os +import torch +import rasterio +import matplotlib.pyplot as plt +import numpy as np +from rasterio.merge import merge +from datasets import SWETsunamiForPlotting +from scipy.interpolate import griddata +import scipy.io as sio +from mpl_toolkits.axes_grid1 import make_axes_locatable +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 32 +plt.rcParams['xtick.labelsize'] = 16 +plt.rcParams['ytick.labelsize'] = 16 +plt.rcParams['font.weight'] = 'bold' +plt.rcParams['axes.titlesize'] = 32 +plt.rcParams['figure.titlesize'] = 32 +plt.rcParams['axes.labelsize'] = 32 + +#activate conda env topos + +## data source ## +### https://topotools.cr.usgs.gov/gmted_viewer/viewer.htm ### +# +# f1 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N090E_20101117_gmted_mea075.tif') +# dta1 = f1.read() +# f2 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N120E_20101117_gmted_mea075.tif') +# dta2 = f2.read() +# f3 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10N150E_20101117_gmted_mea075.tif') +# dta3 = f3.read() +# f4 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S090E_20101117_gmted_mea075.tif') +# dta4 = f4.read() +# f5 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S120E_20101117_gmted_mea075.tif') +# dta5 = f5.read() +# f6 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/10S150E_20101117_gmted_mea075.tif') +# dta6 = f6.read() +# f7 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N090E_20101117_gmted_mea075.tif') +# dta7 = f7.read() +# f8 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N120E_20101117_gmted_mea075.tif') +# dta8 = f8.read() +# f9 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/30N150E_20101117_gmted_mea075.tif') +# dta9 = f9.read() +# f10 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N090E_20101117_gmted_mea075.tif') +# dta10 = f10.read() +# f11 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N120E_20101117_gmted_mea075.tif') +# dta11 = f11.read() +# f12 = rasterio.open('/Users/emcdugald/tsunseiver/geodata/50N150E_20101117_gmted_mea075.tif') +# dta12 = f12.read() +# # +# combined_data = merge([f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12]) +# full_data = combined_data[0][0] +# transform = combined_data[1] +# topo_dataset = rasterio.open('/Users/emcdugald/tsunseiver/geodata/merged.tif', 'w',driver='GTiff', +# height=full_data.shape[0],width=full_data.shape[1],count=1, +# dtype=full_data.dtype,crs='+proj=latlong',transform=transform) +# topo_dataset.write(full_data, 1) +# topo_dataset.close() + +topo = rasterio.open(os.getcwd()+'/geodata/merged.tif') +topo_data = topo.read() + +import geopandas as gpd +df = gpd.read_file(os.getcwd()+'/geodata/ne_10m_admin_0_countries.shp') +chi = df.loc[df['ADMIN'] == 'China'] +jap = df.loc[df['ADMIN'] == 'Japan'] +rus = df.loc[df['ADMIN'] == 'Russia'] +sko = df.loc[df['ADMIN'] == 'South Korea'] +nko = df.loc[df['ADMIN'] == 'North Korea'] +phi = df.loc[df['ADMIN'] == 'Philippines'] +tai = df.loc[df['ADMIN'] == 'Taiwan'] +vie = df.loc[df['ADMIN'] == 'Vietnam'] +mon = df.loc[df['ADMIN'] == 'Mongolia'] + +min_lat = -10 +max_lat = 70 +min_lon = 90 +max_lon = 180 +min_lat_plot = -8 +max_lat_plot = 65 +min_lon_plot = 110 +max_lon_plot = 180 +nlat , nlon = np.shape(topo_data[0]) +lat_vals = np.linspace(max_lat,min_lat,nlat) +lon_vals = np.linspace(min_lon,max_lon,nlon) +min_lat_idx = np.where(lat_vals >= min_lat_plot)[0][-1] +max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx]) +lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +dx = (lon_vals_new[1]-lon_vals_new[0])/2. +dy = (lat_vals_new[1]-lat_vals_new[0])/2. +extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] +topo_vals = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] +print("topo_vals_shape:",np.shape(topo_vals)) +topo_vals = topo_vals[::80,::80] +lat_vals = lat_vals_new[::80] +lon_vals = lon_vals_new[::80] +print("new topo_vals_shape:",np.shape(topo_vals)) + + +def coord_idx(s, ver_num): + if ver_num == 122: + if 0 <= s <= 144: + return 0 + elif 145 <= s <= 289: + return 1 + elif 290 <= s <= 434: + return 2 + elif 435 <= s <= 579: + return 3 + elif 580 <= s <= 724: + return 4 + elif 725 <= s <= 869: + return 5 + elif 870 <= s <= 1014: + return 6 + elif 1015 <= s <= 1159: + return 7 + elif 1160 <= s <= 1304: + return 8 + elif 1305 <= s <= 1449: + return 9 + else: + return 10 + else: + if 0 <= s <= 143: + return 0 + elif 144 <= s <= 287: + return 1 + elif 288 <= s <= 431: + return 2 + elif 432 <= s <= 575: + return 3 + elif 576 <= s <= 719: + return 4 + elif 720 <= s <= 863: + return 5 + elif 864 <= s <= 1007: + return 6 + elif 1008 <= s <= 1151: + return 7 + elif 1152 <= s <= 1295: + return 8 + elif 1296 <= s <= 1439: + return 9 + else: + return 10 + + +def training_lons(): + return np.array([136.6180, 137.9440, 141.4890, 141.7020, + 142.2410, 142.4220, 143.0040, + 143.3780, 143.4350, 143.8990, + 136.8940]) + +def training_lats(): + return np.array([33.0700, 33.1670, 34.5520, 36.1940, + 37.1470, 40.5910, 38.6600, + 40.1760, 40.8600, 42.0840, + 33.6830]) + + +def unseen_lons(): + return np.array([140.8270, 141.4590, 142.0500, 142.5420, + 143.2280, 144.0600, 142.6190, + 143.4160,135.9050,137.0710, + 138.0250]) + +def unseen_lats(): + return np.array([33.3620, 36.5340, 41.0190, 34.7450, + 39.8690, 40.2320, 37.8120, + 41.4150, 33.1230, 33.1840, + 34.1750]) + + +epi_num = 1 +save_path = os.getcwd() + "/lightning_logs/combined_122_125/recons/" +t0 = 75 +t1 = 150 +t2 = 225 +t0_ver_num = 122 +t1_ver_num = 125 +t2_ver_num = 125 +t0_fname = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_145.mat" +t1_fname = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_289.mat" +t2_fname = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_289.mat" +t0_slice = int(t0*6/5) + (epi_num-1)*145 - 1 +t1_slice = int((t1-120)*6/5) + (epi_num-1)*144 - 1 +t2_slice = int((t2-120)*6/5) + (epi_num-1)*144 - 1 + + +#good for epi 7 +# min_lat_plot = 2.5 #0 +# max_lat_plot = 57.5 # 60 +# min_lon_plot = 120 # 115 +# max_lon_plot = 185 # 175 + +min_lat_plot = 2.5 #0 +max_lat_plot = 60 # 60 +min_lon_plot = 120 # 115 +max_lon_plot = 187.5 # 175 + +dx = (lon_vals[1]-lon_vals[0])/2. +dy = (lat_vals[1]-lat_vals[0])/2. +min_lat_idx = np.argsort(np.abs(lat_vals-min_lat_plot))[0] +max_lat_idx = np.argsort(np.abs(lat_vals-max_lat_plot))[0] +min_lon_idx = np.argsort(np.abs(lon_vals-min_lon_plot))[0] +max_lon_idx = np.argsort(np.abs(lon_vals-max_lon_plot))[0] +extent = [lon_vals[min_lon_idx]-dx, lon_vals[max_lon_idx]+dx, + lat_vals[min_lat_idx]-dy, lat_vals[max_lat_idx]+dy] + + +t0_out_path = os.getcwd()+"/version_"+str(t0_ver_num)+"/" +t0_output_im = torch.load(t0_out_path+'tensor_unseen.pt') +t0_true_data, latitude, longitude, mask, max_ht, times, sensors, div = SWETsunamiForPlotting(t0_fname) +t0_true_data *= max_ht +t0_output_im *= max_ht + +t1_out_path = os.getcwd()+"/version_"+str(t1_ver_num)+"/" +t1_output_im = torch.load(t1_out_path+'tensor_unseen.pt') +t1_true_data, latitude, longitude, mask, max_ht, times, sensors, div = SWETsunamiForPlotting(t1_fname) +t1_true_data *= max_ht +t1_output_im *= max_ht + +t2_out_path = os.getcwd()+"/version_"+str(t2_ver_num)+"/" +t2_output_im = torch.load(t2_out_path+'tensor_unseen.pt') +t2_true_data, latitude, longitude, mask, max_ht, times, sensors, div = SWETsunamiForPlotting(t2_fname) +t2_true_data *= max_ht +t2_output_im *= max_ht + +# NEED INTERPOLATED WAVE_HEIGT, LONGITUDE, LATITUDE, MASK +epi_lons = unseen_lons() +epi_lats = unseen_lats() +tsun_lons_tmp = longitude*(180 / np.pi) +tsun_lats_tmp = latitude*(180 / np.pi) +tsun_lons = tsun_lons_tmp +tsun_lats = tsun_lats_tmp +tsun_lons = tsun_lons[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot) ] +tsun_lats = tsun_lats[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +int_Lons, int_Lats = np.meshgrid(lon_vals, lat_vals) +unstruct_coords = np.array([tsun_lons, tsun_lats]).T + +t0_epi_idx = coord_idx(t0_slice,t0_ver_num) +t0_epi_lon = epi_lons[t0_epi_idx] +t0_epi_lat = epi_lats[t0_epi_idx] + +t1_epi_idx = coord_idx(t1_slice,t1_ver_num) +t1_epi_lon = epi_lons[t1_epi_idx] +t1_epi_lat = epi_lats[t1_epi_idx] + +t2_epi_idx = coord_idx(t2_slice,t2_ver_num) +t2_epi_lon = epi_lons[t2_epi_idx] +t2_epi_lat = epi_lats[t2_epi_idx] + + +t0_prediction = t0_output_im[t0_slice].cpu().detach().numpy() +t0_true_data_s = t0_true_data[t0_slice][(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t0_prediction = t0_prediction[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t0_interp_true = np.asarray( + griddata(unstruct_coords, t0_true_data_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t0_interp_pred = np.asarray( + griddata(unstruct_coords, t0_prediction, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t0_interp_true = t0_interp_true[:,:,0] +t0_interp_pred = t0_interp_pred[:,:,0] +t0_interp_abs_err = np.abs(t0_interp_true-t0_interp_pred) +t0_interp_true[(topo_vals != 0)] = 0.0 +t0_interp_pred[(topo_vals != 0)] = 0.0 +t0_interp_abs_err[(topo_vals != 0)] = 0.0 + +t1_prediction = t1_output_im[t1_slice].cpu().detach().numpy() +t1_true_data_s = t1_true_data[t1_slice][(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t1_prediction = t1_prediction[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t1_interp_true = np.asarray( + griddata(unstruct_coords, t1_true_data_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t1_interp_pred = np.asarray( + griddata(unstruct_coords, t1_prediction, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t1_interp_true = t1_interp_true[:,:,0] +t1_interp_pred = t1_interp_pred[:,:,0] +t1_interp_abs_err = np.abs(t1_interp_true-t1_interp_pred) +t1_interp_true[(topo_vals != 0)] = 0.0 +t1_interp_pred[(topo_vals != 0)] = 0.0 +t1_interp_abs_err[(topo_vals != 0)] = 0.0 + + +t2_prediction = t2_output_im[t2_slice].cpu().detach().numpy() +t2_true_data_s = t2_true_data[t2_slice][(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t2_prediction = t2_prediction[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +t2_interp_true = np.asarray( + griddata(unstruct_coords, t2_true_data_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t2_interp_pred = np.asarray( + griddata(unstruct_coords, t2_prediction, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +t2_interp_true = t2_interp_true[:,:,0] +t2_interp_pred = t2_interp_pred[:,:,0] +t2_interp_abs_err = np.abs(t2_interp_true-t2_interp_pred) +t2_interp_true[(topo_vals != 0)] = 0.0 +t2_interp_pred[(topo_vals != 0)] = 0.0 +t2_interp_abs_err[(topo_vals != 0)] = 0.0 + + +t0_min_val = min([np.min(t0_interp_true),np.min(t0_interp_pred)]) +t0_max_val = max([np.max(t0_interp_true),np.max(t0_interp_pred)]) +t1_min_val = min([np.min(t1_interp_true),np.min(t1_interp_pred)]) +t1_max_val = max([np.max(t1_interp_true),np.max(t1_interp_pred)]) +t2_min_val = min([np.min(t2_interp_true),np.min(t2_interp_pred)]) +t2_max_val = max([np.max(t2_interp_true),np.max(t2_interp_pred)]) + +cmin_val = min([t0_min_val,t1_min_val,t2_min_val]) +cmax_val = max([t0_max_val,t1_max_val,t2_max_val]) +emin_val = min([np.min(t0_interp_abs_err),np.min(t1_interp_abs_err),np.min(t2_interp_abs_err)]) +emax_val = max([np.max(t0_interp_abs_err),np.max(t1_interp_abs_err),np.max(t2_interp_abs_err)]) + + +width = (extent[1]-extent[0]) +height = (extent[3]-extent[2]) +aspect = width/height + +fig, axs = plt.subplots(nrows=3,ncols=3,figsize=(16.75,16.25),constrained_layout=True,gridspec_kw={'wspace': 0.01,'hspace': 0.005}) +from numpy.ma import masked_array +topo_img = masked_array(topo_vals, topo_vals == 0.0) +lons_topo_img = masked_array(int_Lons, topo_vals == 0.0) +lats_topo_img = masked_array(int_Lats, topo_vals == 0.0) +lons_tsu_img = masked_array(int_Lons, topo_vals != 0.0) +lats_tsu_img = masked_array(int_Lats, topo_vals != 0.0) + +t0_true_tsu_img = masked_array(t0_interp_true, topo_vals != 0.0) +t0_pred_tsu_img = masked_array(t0_interp_pred, topo_vals != 0.0) +t0_abs_err_img = masked_array(t0_interp_abs_err, topo_vals != 0.0) + + +t0_shw1 = axs[0,0].contourf(lons_tsu_img,np.flip(lats_tsu_img),t0_true_tsu_img,levels=50,extent=extent,cmap='ocean') +t0_shw2 = axs[0,0].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t0_shw3 = axs[0,1].contourf(lons_tsu_img,np.flip(lats_tsu_img),t0_pred_tsu_img,levels=50,extent=extent,cmap='ocean') +t0_shw4 = axs[0,1].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t0_shw5 = axs[0,2].contourf(lons_tsu_img,np.flip(lats_tsu_img),t0_abs_err_img,levels=50,extent=extent,cmap='Blues') +t0_shw6 = axs[0,2].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t0_shw1.set_clim(cmin_val,cmax_val) +t0_shw3.set_clim(cmin_val,cmax_val) + +t1_true_tsu_img = masked_array(t1_interp_true, topo_vals != 0.0) +t1_pred_tsu_img = masked_array(t1_interp_pred, topo_vals != 0.0) +t1_abs_err_img = masked_array(t1_interp_abs_err, topo_vals != 0.0) + +t1_shw1 = axs[1,0].contourf(lons_tsu_img,np.flip(lats_tsu_img),t1_true_tsu_img,levels=50,extent=extent,cmap='ocean') +t1_shw2 = axs[1,0].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t1_shw3 = axs[1,1].contourf(lons_tsu_img,np.flip(lats_tsu_img),t1_pred_tsu_img,levels=50,extent=extent,cmap='ocean') +t1_shw4 = axs[1,1].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t1_shw5 = axs[1,2].contourf(lons_tsu_img,np.flip(lats_tsu_img),t1_abs_err_img,levels=50,extent=extent,cmap='Blues') + +t1_shw6 = axs[1,2].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t1_shw1.set_clim(cmin_val,cmax_val) +t1_shw3.set_clim(cmin_val,cmax_val) + +t2_true_tsu_img = masked_array(t2_interp_true, topo_vals != 0.0) +t2_pred_tsu_img = masked_array(t2_interp_pred, topo_vals != 0.0) +t2_abs_err_img = masked_array(t2_interp_abs_err, topo_vals != 0.0) + +t2_shw1 = axs[2,0].contourf(lons_tsu_img,np.flip(lats_tsu_img),t2_true_tsu_img,levels=50,extent=extent,cmap='ocean') +t2_shw2 = axs[2,0].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t2_shw3 = axs[2,1].contourf(lons_tsu_img,np.flip(lats_tsu_img),t2_pred_tsu_img,levels=50,extent=extent,cmap='ocean') +t2_shw4 = axs[2,1].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t2_shw5 = axs[2,2].contourf(lons_tsu_img,np.flip(lats_tsu_img),t2_abs_err_img,levels=50,extent=extent,cmap='Blues') + +t2_shw6 = axs[2,2].contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +t2_shw1.set_clim(cmin_val,cmax_val) +t2_shw3.set_clim(cmin_val,cmax_val) + + +import shapely.ops as sops +new_shape = sops.unary_union([el for el in chi['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0,0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0,1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0,2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1,0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1,1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1,2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2,0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2,1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2,2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in rus['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in jap['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in nko['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in sko['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in phi['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in tai['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) +new_shape = sops.unary_union([el for el in vie['geometry'].values[0].geoms]) +for geom in new_shape.geoms: + axs[0, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[0, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[1, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 0].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 1].plot(*geom.exterior.xy, c='k', lw=1.0) + axs[2, 2].plot(*geom.exterior.xy, c='k', lw=1.0) + +sensor_longs = sensors[:, 0] +sensor_lats = sensors[:, 1] + +axs[0,0].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[0,0].scatter(t0_epi_lon,t0_epi_lat,s=800,marker="x",color='r') +axs[0,0].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[1,0].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[1,0].scatter(t1_epi_lon,t1_epi_lat,s=800,marker="x",color='r') +axs[1,0].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[2,0].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[2,0].scatter(t2_epi_lon,t2_epi_lat,s=800,marker="x",color='r') +axs[2,0].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[0,0].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[1,0].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[2,0].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) + + +print("circled sensor 1:",(sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14]) +print("circled sensor 2:",(sensor_longs * (180 / np.pi))[23],(sensor_longs * (180 / np.pi))[23]) + +axs[0,1].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[0,1].scatter(t0_epi_lon,t0_epi_lat,s=800,marker="x",color='r') +axs[0,1].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[1,1].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[1,1].scatter(t1_epi_lon,t1_epi_lat,s=800,marker="x",color='r') +axs[1,1].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[2,1].scatter(sensor_longs * (180 / np.pi), sensor_lats * (180 / np.pi), s=200, marker="^",color='yellow',edgecolor='k') +axs[2,1].scatter(t2_epi_lon,t2_epi_lat,s=800,marker="x",color='r') +axs[2,1].scatter((sensor_longs * (180 / np.pi))[14],(sensor_lats * (180 / np.pi))[14],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[0,1].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[1,1].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) +axs[2,1].scatter((sensor_longs * (180 / np.pi))[23],(sensor_lats * (180 / np.pi))[23],s=800,facecolors='none', edgecolors='k', linewidth=5) + +axs[0,0].set_xlim([extent[0], extent[1]]) +axs[0,0].set_ylim([extent[2],extent[3]]) +axs[1,0].set_xlim([extent[0], extent[1]]) +axs[1,0].set_ylim([extent[2],extent[3]]) +axs[2,0].set_xlim([extent[0], extent[1]]) +axs[2,0].set_ylim([extent[2],extent[3]]) +axs[0,1].set_xlim([extent[0], extent[1]]) +axs[0,1].set_ylim([extent[2],extent[3]]) +axs[1,1].set_xlim([extent[0], extent[1]]) +axs[1,1].set_ylim([extent[2],extent[3]]) +axs[2,1].set_xlim([extent[0], extent[1]]) +axs[2,1].set_ylim([extent[2],extent[3]]) +axs[0,2].set_xlim([extent[0], extent[1]]) +axs[0,2].set_ylim([extent[2],extent[3]]) +axs[1,2].set_xlim([extent[0], extent[1]]) +axs[1,2].set_ylim([extent[2],extent[3]]) +axs[2,2].set_xlim([extent[0], extent[1]]) +axs[2,2].set_ylim([extent[2],extent[3]]) +axs[0,0].set_title(r'\textbf{True}') +axs[0,1].set_title(r'\textbf{Predicted}') +axs[0,2].set_title(r'\textbf{Absolute Error}') +axs[0,0].set_aspect(aspect) +axs[0,1].set_aspect(aspect) +axs[0,2].set_aspect(aspect) +axs[1,0].set_aspect(aspect) +axs[1,1].set_aspect(aspect) +axs[1,2].set_aspect(aspect) +axs[2,0].set_aspect(aspect) +axs[2,1].set_aspect(aspect) +axs[2,2].set_aspect(aspect) + + +fig.canvas.draw() +thickness = 0.025 # colorbar thickness in figure coordinates +gap = 0.01 # gap between colorbars in figure coordinates +bottom_offset = 0.06 # vertical offset below subplots (adjust as needed) + +# Get positions of bottom row axes +pos0 = axs[2, 0].get_position() +pos1 = axs[2, 1].get_position() +pos2 = axs[2, 2].get_position() + +# Colorbar 1: under columns 0 and 1 +cbar1_x0 = pos0.x0 +cbar1_x1 = pos1.x1 - gap/2 +cbar1_width = cbar1_x1 - cbar1_x0 + +# Colorbar 2: under column 2 +cbar2_x0 = pos2.x0 + gap/2 +cbar2_x1 = pos2.x1 +cbar2_width = cbar2_x1 - cbar2_x0 + +# Vertical position for both colorbars +cbar_y = pos0.y0 - bottom_offset + +# Create axes for colorbars +cax1 = fig.add_axes([cbar1_x0, cbar_y, cbar1_width, thickness]) +cax2 = fig.add_axes([cbar2_x0, cbar_y, cbar2_width, thickness]) + +# Create colorbars and set labels +cbar1 = fig.colorbar(t0_shw1, cax=cax1, orientation='horizontal') +cbar1.set_label(r'\textbf{Wave Height (' + r'$\mathbf{m}$' + r'\textbf{)}') + +cbar2 = fig.colorbar(t0_shw5, cax=cax2, orientation='horizontal',format="%.2f") +cbar2.set_label(r'\textbf{Error}') +############## + +lon_ticks = np.linspace(extent[0]+2,extent[1]-2,5) + +lat_ticks = np.linspace(extent[2]+2,extent[3]-2,5) + +lon_labels = [] +for x in lon_ticks: + if x < 0: + txt = r'\textbf{' + str(abs(np.round(x, 1))) + '}' + r'\textbf{$^\circ$W}' + lon_labels.append(txt) + elif x > 0: + txt = r'\textbf{' + str(abs(np.round(x, 1))) + '}' + r'\textbf{$^\circ$E}' + lon_labels.append(txt) + else: + lon_labels.append(r"$0^\circ$") + +lat_labels = [] +for y in lat_ticks: + if y < 0: + txt = r'\textbf{' + str(abs(np.round(y, 1))) + '}' + r'\textbf{$^\circ$S}' + lat_labels.append(txt) + elif y > 0: + txt = r'\textbf{' + str(abs(np.round(y, 1))) + '}' + r'\textbf{$^\circ$N}' + lat_labels.append(txt) + else: + lat_labels.append(r"$0^\circ$") + + +axs[0,0].set_yticks(lat_ticks) +axs[0,0].set_yticklabels(lat_labels) +axs[1,0].set_yticks(lat_ticks) +axs[1,0].set_yticklabels(lat_labels) +axs[2,0].set_yticks(lat_ticks) +axs[2,0].set_yticklabels(lat_labels) +axs[2,0].set_xticks(lon_ticks) +axs[2,0].set_xticklabels(lon_labels,rotation=90) +axs[2,1].set_xticks(lon_ticks) +axs[2,1].set_xticklabels(lon_labels,rotation=90) +axs[2,2].set_xticks(lon_ticks) +axs[2,2].set_xticklabels(lon_labels,rotation=90) + + +for ax in fig.get_axes(): + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight('bold') + + +nrows, ncols = axs.shape # assuming axs is a 2D array + + +# Optional: adjust tick label font size or rotation +ax.tick_params(axis='x', labelsize=10) +ax.tick_params(axis='y', labelsize=10) + +for ax in axs.flat: + ax.set_xticks(lon_ticks) + ax.set_yticks(lat_ticks) + ax.grid(color='k', linestyle='--', linewidth=0.5) + +for i in range(nrows): + for j in range(ncols): + ax = axs[i, j] + # Remove y tick labels for all but first column + if j > 0: + ax.set_yticklabels([]) + # Optionally, also remove minor ticks + ax.tick_params(labelleft=False) + # Remove x tick labels for all but last row + if i < nrows - 1: + ax.set_xticklabels([]) + ax.tick_params(labelbottom=False) + +for label in cbar2.ax.get_xticklabels(): + label.set_fontsize(16) +##################################################################### + +plt.savefig(save_path+"h_recon_epi_{}_t0_{}_new_v3.png".format(epi_num,75),dpi=600,bbox_inches='tight') + + + diff --git a/tsunami/make_grl_fig_4.py b/tsunami/make_grl_fig_4.py new file mode 100644 index 0000000..ae00d8d --- /dev/null +++ b/tsunami/make_grl_fig_4.py @@ -0,0 +1,822 @@ +import multiprocessing +multiprocessing.set_start_method("fork") +import torch +import matplotlib.pyplot as plt +import numpy as np +from datasets import SWETsunamiForPlotting3 +from scipy.interpolate import LinearNDInterpolator +from scipy.signal import medfilt +import scipy.io as sio +import sys +import os + +plt.rcParams['text.usetex'] = True +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 28 +plt.rcParams['xtick.labelsize'] = 28 +plt.rcParams['ytick.labelsize'] = 28 +path = os.getcwd()+"/virtual_waveforms/" + + +### EPICENTERS OF UNSEEN DATA ### +def unseen_lons(): + return np.array([140.8270, 141.4590, 142.0500, 142.5420, + 143.2280, 144.0600, 142.6190, + 143.4160,135.9050,137.0710, + 138.0250]) + +def unseen_lats(): + return np.array([33.3620, 36.5340, 41.0190, 34.7450, + 39.8690, 40.2320, 37.8120, + 41.4150, 33.1230, 33.1840, + 34.1750]) + + +epi_lons = unseen_lons() +epi_lats = unseen_lats() + + +ver_num_2hr = 122 +ver_num_4hr = 125 +fig_dir = "/figs_122_125" +log_dir = "/logs_122_125" +fname_2hr = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_145.mat" +fname_4hr = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_289.mat" + + + +logfile = open(path+log_dir+"/metrics_v2.out", 'w') +sys.stdout = logfile + + +type = "unseen" +out_path_2hr = os.getcwd()+"/lightning_logs_0706/version_"+str(ver_num_2hr)+"/" +output_im_2hr = torch.load(out_path_2hr+'tensor'+'_'+str(type)+'.pt').numpy() +out_path_4hr = os.getcwd()+"/lightning_logs_0706/version_"+str(ver_num_4hr)+"/" +output_im_4hr = torch.load(out_path_4hr+'tensor'+'_'+str(type)+'.pt').numpy() +true_data_2hr, latitude_2hr, longitude_2hr, mask_2hr, max_ht_2hr, times_2hr, sensors_2hr, div_2hr = SWETsunamiForPlotting3(fname_2hr) +true_data_2hr *= max_ht_2hr +output_im_2hr *= max_ht_2hr +true_data_4hr, latitude_4hr, longitude_4hr, mask_4hr, max_ht_4hr, times_4hr, sensors_4hr, div_4hr = SWETsunamiForPlotting3(fname_4hr) +true_data_4hr *= max_ht_4hr +output_im_4hr *= max_ht_4hr +tsun_lons = longitude_4hr*(180 / np.pi) +tsun_lats = latitude_4hr*(180 / np.pi) + + + + +### COLLECT SIMULATION DATA INTO ONE FOUR HOUR SET ### +intervals_2hr = [(0, 145), (145, 290), (290, 435), (435, 580), (580, 725), (725, 870), (870, 1015), + (1015, 1160), (1160, 1305), (1305, 1450), (1450, 1595)] +sim1_2hr = true_data_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] +sim2_2hr = true_data_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] +sim3_2hr = true_data_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] +sim4_2hr = true_data_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] +sim5_2hr = true_data_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] +sim6_2hr = true_data_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] +sim7_2hr = true_data_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] +sim8_2hr = true_data_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] +sim9_2hr = true_data_2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] +sim10_2hr = true_data_2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] +sim11_2hr = true_data_2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] +intervals_4hr = [(0, 144), (144, 288), (288, 432), (432, 576), (576, 720), (720, 864), (864, 1008), + (1008, 1152), (1152, 1296), (1296, 1440), (1440, 1584)] +sim1_4hr = true_data_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] +sim2_4hr = true_data_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] +sim3_4hr = true_data_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] +sim4_4hr = true_data_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] +sim5_4hr = true_data_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] +sim6_4hr = true_data_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] +sim7_4hr = true_data_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] +sim8_4hr = true_data_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] +sim9_4hr = true_data_4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] +sim10_4hr = true_data_4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] +sim11_4hr = true_data_4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] +sim1_full = np.concatenate((sim1_2hr,sim1_4hr),axis=0) +sim2_full = np.concatenate((sim2_2hr,sim2_4hr),axis=0) +sim3_full = np.concatenate((sim3_2hr,sim3_4hr),axis=0) +sim4_full = np.concatenate((sim4_2hr,sim4_4hr),axis=0) +sim5_full = np.concatenate((sim5_2hr,sim5_4hr),axis=0) +sim6_full = np.concatenate((sim6_2hr,sim6_4hr),axis=0) +sim7_full = np.concatenate((sim7_2hr,sim7_4hr),axis=0) +sim8_full = np.concatenate((sim8_2hr,sim8_4hr),axis=0) +sim9_full = np.concatenate((sim9_2hr,sim9_4hr),axis=0) +sim10_full = np.concatenate((sim10_2hr,sim10_4hr),axis=0) +sim11_full = np.concatenate((sim11_2hr,sim11_4hr),axis=0) + +### COLLECT SENSEIVER RECONS INTO ONE FOUR HOUR SET ### +sens1_2hr = output_im_2hr[intervals_2hr[0][0]:intervals_2hr[0][1]] +sens2_2hr = output_im_2hr[intervals_2hr[1][0]:intervals_2hr[1][1]] +sens3_2hr = output_im_2hr[intervals_2hr[2][0]:intervals_2hr[2][1]] +sens4_2hr = output_im_2hr[intervals_2hr[3][0]:intervals_2hr[3][1]] +sens5_2hr = output_im_2hr[intervals_2hr[4][0]:intervals_2hr[4][1]] +sens6_2hr = output_im_2hr[intervals_2hr[5][0]:intervals_2hr[5][1]] +sens7_2hr = output_im_2hr[intervals_2hr[6][0]:intervals_2hr[6][1]] +sens8_2hr = output_im_2hr[intervals_2hr[7][0]:intervals_2hr[7][1]] +sens9_2hr = output_im_2hr[intervals_2hr[8][0]:intervals_2hr[8][1]] +sens10_2hr = output_im_2hr[intervals_2hr[9][0]:intervals_2hr[9][1]] +sens11_2hr = output_im_2hr[intervals_2hr[10][0]:intervals_2hr[10][1]] +sens1_4hr = output_im_4hr[intervals_4hr[0][0]:intervals_4hr[0][1]] +sens2_4hr = output_im_4hr[intervals_4hr[1][0]:intervals_4hr[1][1]] +sens3_4hr = output_im_4hr[intervals_4hr[2][0]:intervals_4hr[2][1]] +sens4_4hr = output_im_4hr[intervals_4hr[3][0]:intervals_4hr[3][1]] +sens5_4hr = output_im_4hr[intervals_4hr[4][0]:intervals_4hr[4][1]] +sens6_4hr = output_im_4hr[intervals_4hr[5][0]:intervals_4hr[5][1]] +sens7_4hr = output_im_4hr[intervals_4hr[6][0]:intervals_4hr[6][1]] +sens8_4hr = output_im_4hr[intervals_4hr[7][0]:intervals_4hr[7][1]] +sens9_4hr = output_im_4hr[intervals_4hr[8][0]:intervals_4hr[8][1]] +sens10_4hr = output_im_4hr[intervals_4hr[9][0]:intervals_4hr[9][1]] +sens11_4hr = output_im_4hr[intervals_4hr[10][0]:intervals_4hr[10][1]] +sens1_full = np.concatenate((sens1_2hr,sens1_4hr),axis=0) +sens2_full = np.concatenate((sens2_2hr,sens2_4hr),axis=0) +sens3_full = np.concatenate((sens3_2hr,sens3_4hr),axis=0) +sens4_full = np.concatenate((sens4_2hr,sens4_4hr),axis=0) +sens5_full = np.concatenate((sens5_2hr,sens5_4hr),axis=0) +sens6_full = np.concatenate((sens6_2hr,sens6_4hr),axis=0) +sens7_full = np.concatenate((sens7_2hr,sens7_4hr),axis=0) +sens8_full = np.concatenate((sens8_2hr,sens8_4hr),axis=0) +sens9_full = np.concatenate((sens9_2hr,sens9_4hr),axis=0) +sens10_full = np.concatenate((sens10_2hr,sens10_4hr),axis=0) +sens11_full = np.concatenate((sens11_2hr,sens11_4hr),axis=0) + + +### SENSOR INDICES, LOCATIONS AND BATHYMETRY ### +sensors = sensors_2hr*(180/np.pi) +mat_data = sio.loadmat(os.getcwd()+"/Data/tsunami/"+fname_2hr) +sensor_indices = mat_data['sensor_loc_indices'] +sens_lons = tsun_lons[sensor_indices[0]] +sens_lats = tsun_lats[sensor_indices[0]] +bathymetry = mat_data['ocn_floor'] +xy = np.c_[tsun_lons, tsun_lats] +bath = LinearNDInterpolator(xy, bathymetry[0]) + +### RESTRICT DATA TO SENSIBLE WINDOW ### +min_lat_plot = 10 +max_lat_plot = 45 +min_lon_plot = 125 +max_lon_plot = 160 +in_window_indicator = np.zeros_like(sens_lons) +for i in range(len(sens_lons)): + if min_lon_plot <= sens_lons[i] <= max_lon_plot and min_lat_plot <= sens_lats[i] <= max_lat_plot: + in_window_indicator[i] += 1 +in_window_sensor_indices = sensor_indices[0][np.where(in_window_indicator==1.0)] +sens_lons_inner = tsun_lons[in_window_sensor_indices] +sens_lats_inner = tsun_lats[in_window_sensor_indices] +sim1_sens_vals_inner = sim1_full[:,in_window_sensor_indices][:,:,0].T +sim2_sens_vals_inner = sim2_full[:,in_window_sensor_indices][:,:,0].T +sim3_sens_vals_inner = sim3_full[:,in_window_sensor_indices][:,:,0].T +sim4_sens_vals_inner = sim4_full[:,in_window_sensor_indices][:,:,0].T +sim5_sens_vals_inner = sim5_full[:,in_window_sensor_indices][:,:,0].T +sim6_sens_vals_inner = sim6_full[:,in_window_sensor_indices][:,:,0].T +sim7_sens_vals_inner = sim7_full[:,in_window_sensor_indices][:,:,0].T +sim8_sens_vals_inner = sim8_full[:,in_window_sensor_indices][:,:,0].T +sim9_sens_vals_inner = sim9_full[:,in_window_sensor_indices][:,:,0].T +sim10_sens_vals_inner = sim10_full[:,in_window_sensor_indices][:,:,0].T +sim11_sens_vals_inner = sim11_full[:,in_window_sensor_indices][:,:,0].T + + +### MAKE A MATRIX CONSISTING OF DISTANCES BETWEEN PAIRS OF SENSORS ### +sens_distance_arr = np.zeros(shape=(len(sens_lons_inner),len(sens_lons_inner))) +for i in range(len(sens_lons_inner)): + for j in range(len(sens_lons_inner)): + if i= min_lat_plot)[0][-1] +max_lat_idx = np.where(lat_vals >= max_lat_plot)[0][-1] +min_lon_idx = np.where(lon_vals >= min_lon_plot)[0][0] +max_lon_idx = np.where(lon_vals >= max_lon_plot)[0][0] +nlat_plot, nlon_plot = np.shape(topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx]) +lat_vals_new = np.linspace(min_lat_plot,max_lat_plot,nlat_plot) +lon_vals_new = np.linspace(min_lon_plot,max_lon_plot,nlon_plot) +dx = (lon_vals_new[1]-lon_vals_new[0])/2. +dy = (lat_vals_new[1]-lat_vals_new[0])/2. +extent = [lon_vals_new[0]-dx, lon_vals_new[-1]+dx, lat_vals_new[0]-dy, lat_vals_new[-1]+dy] +topo_vals = topo_data[0][max_lat_idx:min_lat_idx,min_lon_idx:max_lon_idx] +topo_vals = topo_vals[::80,::80] +lat_vals = lat_vals_new[::80] +lon_vals = lon_vals_new[::80] +dx = (lon_vals[1]-lon_vals[0])/2. +dy = (lat_vals[1]-lat_vals[0])/2. +min_lat_idx = np.argsort(np.abs(lat_vals-min_lat_plot))[0] +max_lat_idx = np.argsort(np.abs(lat_vals-max_lat_plot))[0] +min_lon_idx = np.argsort(np.abs(lon_vals-min_lon_plot))[0] +max_lon_idx = np.argsort(np.abs(lon_vals-max_lon_plot))[0] +extent = [lon_vals[min_lon_idx]-dx, lon_vals[max_lon_idx]+dx, + lat_vals[min_lat_idx]-dy, lat_vals[max_lat_idx]+dy] +ocn_floor = bathymetry +tsun_lons_tmp = tsun_lons +tsun_lats_tmp = tsun_lats +tsun_lons_for_bath = tsun_lons_tmp +tsun_lats_for_bath = tsun_lats_tmp +tsun_lons_for_bath = tsun_lons[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot) ] +tsun_lats_for_bath = tsun_lats[(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +int_Lons, int_Lats = np.meshgrid(lon_vals, lat_vals) +ocn_floor_s = ocn_floor[0][(tsun_lons_tmp >= min_lon_plot) & + (tsun_lons_tmp <= max_lon_plot) & + (tsun_lats_tmp >= min_lat_plot) & + (tsun_lats_tmp <= max_lat_plot)] +unstruct_coords = np.array([tsun_lons_for_bath, tsun_lats_for_bath]).T +interp_true = np.asarray( + griddata(unstruct_coords, ocn_floor_s, (int_Lons, np.flip(int_Lats)), method='cubic', fill_value=0)) +interp_true[(topo_vals != 0)] = 0.0 +width = (extent[1]-extent[0]) +height = (extent[3]-extent[2]) +aspect = width/height +fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(10,6.15)) +from numpy.ma import masked_array +topo_img = masked_array(topo_vals, topo_vals == 0.0) +lons_topo_img = masked_array(int_Lons, topo_vals == 0.0) +lats_topo_img = masked_array(int_Lats, topo_vals == 0.0) +true_tsu_img = masked_array(interp_true, topo_vals != 0.0) +lons_tsu_img = masked_array(int_Lons, topo_vals != 0.0) +lats_tsu_img = masked_array(int_Lats, topo_vals != 0.0) +shw1 = ax.contourf(lons_tsu_img,np.flip(lats_tsu_img),true_tsu_img,levels=50,extent=extent,cmap='gist_earth') +shw2 = ax.contourf(lons_topo_img,np.flip(lats_topo_img),topo_img,levels=50,extent=extent,cmap='summer') +ax.scatter(real_lons,real_lats,s=480,c='w',marker='o', label=r'\textbf{DART}') +ax.scatter(virtual_lons,virtual_lats,s=480,c='k',marker='x', label=r'\textbf{Virtual}') +ax.text(virtual_lons[0]+1,virtual_lats[0]+1,r'$\mathbf{1}$',color='k',fontsize=32) +ax.text(virtual_lons[1]+1,virtual_lats[1]+1,r'$\mathbf{2}$',color='k',fontsize=32) +ax.text(virtual_lons[2]+1,virtual_lats[2]+1,r'$\mathbf{3}$',color='k',fontsize=32) +ax.text(virtual_lons[3]+1,virtual_lats[3]+1,r'$\mathbf{4}$',color='k',fontsize=32) +ax.text(virtual_lons[4]+1,virtual_lats[4]+1,r'$\mathbf{5}$',color='k',fontsize=32) +ax.text(virtual_lons[5]+1,virtual_lats[5]+1,r'$\mathbf{6}$',color='k',fontsize=32) + +ax.legend(loc='upper left') +ax.set_xlim([extent[0], extent[1]]) +ax.set_ylim([extent[2],extent[3]]) +ax.set_ylabel(r'\textbf{Latitude}', labelpad=10) +ax.set_xlabel(r'\textbf{Longitude}', labelpad=10) +cbar = fig.colorbar(shw1,ax=ax,fraction=.047, pad=0.04) +cbar.set_label(r'\textbf{Bathymetry (m)}') +ax.grid(color = 'gray', linestyle = '--', linewidth = 0.5) +ticks = cbar.get_ticks() +ticklabels = [r'\textbf{%g}' % t for t in ticks] +cbar.set_ticks(ticks) +cbar.set_ticklabels(ticklabels) +bath_save_dir = os.getcwd()+"/GRL_figs_0706/virtual_waveforms/" +plt.savefig(bath_save_dir+"bath.png",bbox_inches='tight',dpi=400) +##################################################################### + + +min_lat_plot = 10 +max_lat_plot = 45 +min_lon_plot = 125 +max_lon_plot = 160 + +### GET LEFT, RIGHT, AND VIRTUAL BATHYMETRY VALUES ### +real_sens_left_bath = np.zeros(6) +for i in range(6): + bath_lon = tsun_lons[in_window_sensor_indices[five_smallest_indices[0][i]]] + bath_lat = tsun_lats[in_window_sensor_indices[five_smallest_indices[0][i]]] + bathpt = bath(bath_lon,bath_lat) + real_sens_left_bath[i] = bathpt + +real_sens_right_bath = np.zeros(6) +for i in range(6): + bath_lon = tsun_lons[in_window_sensor_indices[five_smallest_indices[1][i]]] + bath_lat = tsun_lats[in_window_sensor_indices[five_smallest_indices[1][i]]] + bathpt = bath(bath_lon,bath_lat) + real_sens_right_bath[i] = bathpt + + +virtual_sens_bath = np.zeros(6) +for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + bathpt = bath(vlon, vlat) + virtual_sens_bath[i] = bathpt + +print("Real Sensor Pair Locations: ",sensor_pair_locs) +print("Virtual Sensor Locations: ",virtual_sensor_locs) +print("Left Sens Bathymetry: ",real_sens_left_bath) +print("Right Sens Bathymetry: ",real_sens_right_bath) +print("Virtual Sens Bathymetry: ",virtual_sens_bath) + +############################## +### GET THE REAL WAVEFORMS ### +all_inner_sens_vals = [sim1_sens_vals_inner,sim2_sens_vals_inner,sim3_sens_vals_inner,sim4_sens_vals_inner, + sim5_sens_vals_inner,sim6_sens_vals_inner,sim7_sens_vals_inner,sim8_sens_vals_inner, + sim9_sens_vals_inner, sim10_sens_vals_inner, sim11_sens_vals_inner] + +all_true_for_interp = [sim1_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim2_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim3_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim4_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim5_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim6_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim7_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim8_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim9_full[:, (tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim10_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sim11_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + ] + +all_senseiver_for_interp = [sens1_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens2_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens3_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens4_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens5_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens6_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens7_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens8_full[:,(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)], + sens9_full[:, (tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + , + sens10_full[:, (tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + , + sens11_full[:, (tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + ] + +interp_lons = tsun_lons[(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + +interp_lats = tsun_lats[(tsun_lons >= min_lon_plot) & (tsun_lons <= max_lon_plot) & + (tsun_lats >= min_lat_plot) & (tsun_lats <= max_lat_plot)] + +xy_inner = np.c_[interp_lons, interp_lats] + +sens_mean_err_all_sims = [] +wang_mean_err_all_sims = [] +true_max_amplitudes_full = [] +sens_max_amplitudes_full = [] +wang_max_amplitudes_full = [] +true_arrival_times_full = [] +sens_arrival_times_full = [] +wang_arrival_times_full = [] + +for sim_num in range(11): + display_num = sim_num + 1 + print("########## Simulation {} ##########".format(display_num)) + real_sens_left_waveforms = np.zeros(shape=(6,289)) + for i in range(6): + real_sens_left_waveforms[i] = all_inner_sens_vals[sim_num][five_smallest_indices[0][i]] + + real_sens_right_waveforms = np.zeros(shape=(6,289)) + for i in range(6): + real_sens_right_waveforms[i] = all_inner_sens_vals[sim_num][five_smallest_indices[1][i]] + + ### GET THE ARRIVAL TIMES FOR REAL WAVEFORMS ### + real_sens_left_arrival_times = np.zeros(6) + real_sens_left_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_sens_left_waveforms[i]: + if val < 3e-2: + t += 50.0 + idx += 1 + else: + real_sens_left_arrival_times[i] += t + real_sens_left_arrival_time_indices[i] += idx + break + + real_sens_right_arrival_times = np.zeros(6) + real_sens_right_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_sens_right_waveforms[i]: + if val < 3e-2: + t += 50.0 + idx += 1 + else: + real_sens_right_arrival_times[i] += t + real_sens_right_arrival_time_indices[i] += idx + break + + wang_arrival_times = .5*real_sens_left_arrival_times+.5*real_sens_right_arrival_times + + t_seconds = np.arange(0, 50 * 289, 50) + t_minutes = t_seconds/60 + wang_arrival_time_indices = np.zeros(6) + for i in range(6): + wang_arrival_time_indices[i] += np.argmin(np.abs(wang_arrival_times[i]-t_seconds)) + + interpolated_virtual_sens_waveforms = np.zeros(shape=(6,289)) + stop_plot_indices = np.zeros(6) + for i in range(6): + l_idx = int(real_sens_left_arrival_time_indices[i]) + r_idx = int(real_sens_right_arrival_time_indices[i]) + v_idx = int(wang_arrival_time_indices[i]) + smaller_idx = min(l_idx,r_idx) + res = int(289 - v_idx) + left_wf_shifted = real_sens_left_waveforms[i][smaller_idx:smaller_idx+res] + right_wf_shifted = real_sens_right_waveforms[i][smaller_idx:smaller_idx+res] + interpolated_virtual_sens_waveforms[i][v_idx:v_idx+res] = (((.5*left_wf_shifted)* + ((-real_sens_left_bath[i])**(.25))+ + (.5*right_wf_shifted)* + ((-real_sens_right_bath[i])**(.25))) + /((-virtual_sens_bath[i])**(.25))) + + + + real_virtual_sens_waveforms = np.zeros(shape=(6, 289)) + for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + for j in range(289): + tsu = LinearNDInterpolator(xy_inner, all_true_for_interp[sim_num][j,:,0]) + real_virtual_sens_waveforms[i,j] = tsu(vlon,vlat) + + senseiver_virtual_sens_waveforms = np.zeros(shape=(6, 289)) + for i in range(6): + vlon = virtual_lons[i] + vlat = virtual_lats[i] + for j in range(289): + senseiver = LinearNDInterpolator(xy_inner, all_senseiver_for_interp[sim_num][j,:,0]) + senseiver_virtual_sens_waveforms[i,j] = senseiver(vlon,vlat) + + #filter the senseiver recons + for i in range(6): + # SAME AS COMBINED PLOT FILTERING v2 + signal = senseiver_virtual_sens_waveforms[i] + filtered_signal = medfilt(signal, 9) + filtered_signal = medfilt(filtered_signal, 3) + senseiver_virtual_sens_waveforms[i] = filtered_signal + + true_arrival_times = np.zeros(6) + true_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in real_virtual_sens_waveforms[i]: + if val < 3e-2: + t += 50.0 + idx += 1 + else: + true_arrival_times[i] += t + true_arrival_time_indices[i] += idx + break + + true_max_amplitudes = np.zeros(6) + for i in range(6): + true_max_amplitudes[i] = np.max(np.abs(real_virtual_sens_waveforms[i])) + + sens_arrival_times = np.zeros(6) + sens_arrival_time_indices = np.zeros(6) + for i in range(6): + t = 0 + idx = 0 + for val in senseiver_virtual_sens_waveforms[i]: + if val < 3e-2: + t += 50.0 + idx += 1 + else: + sens_arrival_times[i] += t + sens_arrival_time_indices[i] += idx + break + + senseiver_max_amplitudes = np.zeros(6) + for i in range(6): + senseiver_max_amplitudes[i] = np.max(np.abs(senseiver_virtual_sens_waveforms[i])) + + wang_max_amplitudes = np.zeros(6) + for i in range(6): + wang_max_amplitudes[i] = np.max(np.abs(interpolated_virtual_sens_waveforms[i])) + + print("Left Sens Arrival Times: ", real_sens_left_arrival_times/60) + print("Right Sens Arrival Times: ", real_sens_right_arrival_times/60) + print("Virtual Arrival Times (True): ", true_arrival_times/60) + print("Virtual Arrival Times (Wang): ", wang_arrival_times/60) + print("Virtual Arrival Times (Sens): ", sens_arrival_times/60) + print("Virtual Max Amp (True): ", true_max_amplitudes) + print("Virtual Max Amp (Wang): ", wang_max_amplitudes) + print("Virtual Max Amp (Sens): ", senseiver_max_amplitudes) + print("Wang AT MAE: ",np.mean(np.abs(true_arrival_times-wang_arrival_times))/60) + print("Senseiver AT MAE: ",np.mean(np.abs(true_arrival_times-sens_arrival_times))/60) + print("Wang MA MAE: ",np.mean(np.abs(true_max_amplitudes-wang_max_amplitudes))) + print("Senseiver MA MAE: ",np.mean(np.abs(true_max_amplitudes-senseiver_max_amplitudes))) + true_max_amplitudes_full.append(true_max_amplitudes) + sens_max_amplitudes_full.append(senseiver_max_amplitudes) + wang_max_amplitudes_full.append(wang_max_amplitudes) + true_arrival_times_full.append(true_arrival_times/60) + sens_arrival_times_full.append(sens_arrival_times/60) + wang_arrival_times_full.append(wang_arrival_times/60) + + senseiver_mean_arr = [] + wang_mean_arr = [] + times = (5/6)*np.arange(0, 289, 1) + + + + for i in range(6): + plt.rcParams["font.size"] = 16 + plt.rcParams['xtick.labelsize'] = 16 + plt.rcParams['ytick.labelsize'] = 16 + fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 3)) + axs[0].plot(times, real_virtual_sens_waveforms[i], c='k') + axs[0].plot(times, senseiver_virtual_sens_waveforms[i], c='g') + axs[0].plot(times, interpolated_virtual_sens_waveforms[i], c='r') + axs[0].legend(['SWE', 'Senseiver', 'LIHFP']) + txt = r'\textbf{Height at Sensor }' + r'\textbf{' + str(i + 1) + r'}' + txt2 = r'\textbf{Absolute Error at Sensor }' + r'\textbf{' + str(i + 1) + r'}' + axs[0].set(xlabel=r"\textbf{Time (minutes)}", ylabel=txt) + axs[0].set_xlim(0, 240) + axs[1].plot(times, np.abs(real_virtual_sens_waveforms[i] - senseiver_virtual_sens_waveforms[i]), c='k') + axs[1].plot(times, np.abs(real_virtual_sens_waveforms[i] - interpolated_virtual_sens_waveforms[i]), c='r') + axs[1].legend([r'\textbf{Senseiver: MAE = }' + r'\textbf{' + str( + round(np.mean(np.abs(real_virtual_sens_waveforms[i] - senseiver_virtual_sens_waveforms[i])), 2)) + r'}', + r'\textbf{LIHFP: MAE = }' + r'\textbf{' + str(round( + np.mean(np.abs(real_virtual_sens_waveforms[i] - interpolated_virtual_sens_waveforms[i])), + 2)) + r'}']) + axs[1].set(xlabel=r"\textbf{Time (minutes)}", ylabel=txt2) + axs[1].set_xlim(0, 240) + plt.tight_layout() + plt.savefig(os.getcwd()+"/virtual_waveforms" + fig_dir + + "/sim_{}_senslon_{}_senslat_{}_sensnum_{}_v2.png".format( + display_num, + round(virtual_lons[i], 2), + round(virtual_lats[i], 2), i + 1), bbox_inches='tight', dpi=400 + ) + plt.close() + + fig, ax = plt.subplots(figsize=(10, 3.25)) + ax.plot(times, real_virtual_sens_waveforms[i], c='k') + ax.plot(times, senseiver_virtual_sens_waveforms[i], c='#0072B2') + ax.plot(times, interpolated_virtual_sens_waveforms[i], c='#D55E00') + txt = r'\textbf{Height at Sensor }' + r'\textbf{' + str(i + 1) + r'}' + txt2 = r'\textbf{Absolute Error at Sensor }' + r'\textbf{' + str(i + 1) + r'}' + ax.set(xlabel=r"\textbf{Time (minutes)}") + ax.set_ylabel(txt) + ax.yaxis.set_label_position("right") + ax.yaxis.tick_right() + ax.set_xlim(0, 240) + ax.legend([r'\textbf{SWE}', + r'\textbf{Senseiver: MAE = }' + r'\textbf{' + str( + round(np.mean(np.abs(real_virtual_sens_waveforms[i] - + senseiver_virtual_sens_waveforms[i])), 2)) + r'}', + r'\textbf{LIHFP: MAE = }' + r'\textbf{' + str(round( + np.mean(np.abs(real_virtual_sens_waveforms[i] - + interpolated_virtual_sens_waveforms[i])), + 2)) + r'}']) + for label in ax.get_xticklabels(): + label.set_fontweight('bold') + for label in ax.get_yticklabels(): + label.set_fontweight('bold') + plt.tight_layout() + plt.savefig(os.getcwd()+ "/virtual_waveforms" + fig_dir + + "/single_sim_{}_senslon_{}_senslat_{}_sensnum_{}_v2.png".format( + display_num, + round(virtual_lons[i], 2), + round(virtual_lats[i], 2), i + 1), bbox_inches='tight', dpi=400 + ) + plt.close() + + + print("Senseiver Mean Error for sensor {}:".format(i+1),np.mean(np.abs(real_virtual_sens_waveforms[i]-senseiver_virtual_sens_waveforms[i]))) + print("Wang Mean Error for sensor {}:".format(i+1),np.mean(np.abs(real_virtual_sens_waveforms[i]-interpolated_virtual_sens_waveforms[i]))) + senseiver_mean_arr.append(np.mean(np.abs(real_virtual_sens_waveforms[i]-senseiver_virtual_sens_waveforms[i]))) + wang_mean_arr.append(np.mean(np.abs(real_virtual_sens_waveforms[i]-interpolated_virtual_sens_waveforms[i]))) + print("Mean Senseiver Error for Sim {}:".format(display_num), + np.mean(senseiver_mean_arr)) + print("Mean Wang Error for Sim {}:".format(display_num), + np.mean(wang_mean_arr)) + sens_mean_err_all_sims.append(np.mean(senseiver_mean_arr)) + wang_mean_err_all_sims.append(np.mean(wang_mean_arr)) + +print("Senseiver Mean Error for all Sims: ",np.mean(sens_mean_err_all_sims)) +print("Wang Mean Error for all Sims: ",np.mean(wang_mean_err_all_sims)) +print("Senseiver Error STD for all Sims: ",np.std(sens_mean_err_all_sims)) +print("Wang Error STD for all Sims: ",np.std(wang_mean_err_all_sims)) + + + +plt.rcParams["font.size"] = 20 +plt.rcParams['xtick.labelsize'] = 20 +plt.rcParams['ytick.labelsize'] = 20 + + +true_max_amp_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + true_max_amp_mat[i,j] = true_max_amplitudes_full[i][j] + +wang_max_amp_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + wang_max_amp_mat[i,j] = wang_max_amplitudes_full[i][j] + +sens_max_amp_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + sens_max_amp_mat[i,j] = sens_max_amplitudes_full[i][j] + +fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(8,4)) +im1 = axs[0].imshow(true_max_amp_mat,cmap='bwr') +axs[0].set_title(r'\textbf{True}') +axs[0].set_xlabel(r'\textbf{Virtual Sensors}') +axs[0].set_ylabel(r'\textbf{Simulations}') +im2 = axs[1].imshow(wang_max_amp_mat,cmap='bwr') +axs[1].set_title(r'\textbf{LIHFP}') +im3 = axs[2].imshow(sens_max_amp_mat,cmap='bwr') +axs[2].set_title(r'\textbf{Senseiver}') +pos = axs[2].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im3, cax=cax) +cbar.set_label(r'\textbf{Max Amplitudes (m)}') + + +fig.suptitle(r'\textbf{Max Amplitudes}',y=.95) +plt.savefig(os.getcwd()+ "/virtual_waveforms" + fig_dir+ + "/MaxAmpMat_v2.png",dpi=400) + +true_arr_time_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + true_arr_time_mat[i, j] = true_arrival_times_full[i][j] + +wang_arr_time_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + wang_arr_time_mat[i, j] = wang_arrival_times_full[i][j] + +sens_arr_time_mat = np.zeros(shape=(11,6)) +for i in range(11): + for j in range(6): + sens_arr_time_mat[i, j] = sens_arrival_times_full[i][j] + +fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(8,4)) +im1 = axs[0].imshow(true_arr_time_mat,cmap='bwr') +axs[0].set_title(r'\textbf{True}') +axs[0].set_xlabel(r'\textbf{Virtual Sensors}') +axs[0].set_ylabel(r'\textbf{Simulations}') +im2 = axs[1].imshow(wang_arr_time_mat,cmap='bwr') +axs[1].set_title(r'\textbf{LIHFP}') +im3 = axs[2].imshow(sens_arr_time_mat,cmap='bwr') +axs[2].set_title(r'\textbf{Senseiver}') + +pos = axs[2].get_position() +cax = fig.add_axes([pos.x1 + 0.02, pos.y0, 0.02, pos.height]) +cbar = plt.colorbar(im3, cax=cax) +cbar.set_label(r'\textbf{Arrival Times (mins)}') + + +fig.suptitle(r'\textbf{Arrival Times}',y=.95) +plt.savefig(os.getcwd()+"/virtual_waveforms/" + fig_dir+ + "/ArrivalTimeMat_v2.png",dpi=400 + ) + +wang_ma_err = np.abs(true_max_amp_mat-wang_max_amp_mat) +sens_ma_err = np.abs(true_max_amp_mat-sens_max_amp_mat) +vmin = min(np.nanmin(wang_ma_err), np.nanmin(sens_ma_err)) +vmax = max(np.nanmax(wang_ma_err), np.nanmax(sens_ma_err)) + +fig, axs = plt.subplots(nrows=1,ncols=2,figsize=(5,6.15),constrained_layout=True) +im1 = axs[0].imshow(wang_ma_err, cmap='coolwarm', vmin=vmin, vmax=vmax) +axs[0].set_ylabel(r'\textbf{LIHFP Avg: }' + r'\textbf{'+str(round(np.mean(wang_ma_err),1))+r'}'r'\textbf{ (m)}') +im2 = axs[1].imshow(sens_ma_err, cmap='coolwarm', vmin=vmin, vmax=vmax) +axs[1].set_ylabel(r'\textbf{Senseiver Avg: }'+ r'\textbf{'+str(round(np.mean(sens_ma_err),1))+r'}'+r'\textbf{ (m)}') +pos = axs[1].get_position() +cbar = fig.colorbar(im1, ax=axs, orientation='horizontal', fraction=0.05, pad=0.02) +ticks = np.linspace(vmin, vmax, num=3) +ticklabels = [r'\textbf{%.2f}' % t for t in ticks] +cbar.set_ticks(ticks) +cbar.set_ticklabels(ticklabels) +cbar.set_label(r'\textbf{Max Amplitude MAE}') +axs[0].set_xticks([]) +axs[0].set_yticks([]) +axs[1].yaxis.tick_right() +axs[1].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False) +axs[1].set_yticks(np.arange(11)) +axs[1].set_xticks(np.arange(6)) +axs[1].set_yticklabels([r'\textbf{%d}' % (i+1) for i in range(11)]) +axs[1].set_xticklabels([r'\textbf{%d}' % (i+1) for i in range(6)]) +plt.savefig(os.getcwd()+ "/virtual_waveforms" + fig_dir+ + "/MaxAmpMatErr_v2.png",bbox_inches='tight',dpi=400) + + + +wang_at_err = np.abs(true_arr_time_mat-wang_arr_time_mat) +sens_at_err = np.abs(true_arr_time_mat-sens_arr_time_mat) +vmin = min(np.nanmin(wang_at_err), np.nanmin(sens_at_err)) +vmax = max(np.nanmax(wang_at_err), np.nanmax(sens_at_err)) + +fig, axs = plt.subplots(nrows=1,ncols=2,figsize=(5,6.15),constrained_layout=True) +im1 = axs[0].imshow(wang_at_err, cmap='coolwarm', vmin=vmin, vmax=vmax) +axs[0].set_ylabel(r'\textbf{LIHFP Avg: }' + r'\textbf{'+str(round(np.mean(wang_at_err),1))+r'}'r'\textbf{ (mins)}') +im2 = axs[1].imshow(sens_at_err, cmap='coolwarm', vmin=vmin, vmax=vmax) +axs[1].set_ylabel(r'\textbf{Senseiver Avg: }' + r'\textbf{'+str(round(np.mean(sens_at_err),1))+r'}'r'\textbf{ (mins)}') +pos = axs[1].get_position() +cbar = fig.colorbar(im1, ax=axs, orientation='horizontal', fraction=0.05, pad=0.02) +ticks = np.linspace(vmin, vmax, num=3) +ticklabels = [r'\textbf{%.2f}' % t for t in ticks] +cbar.set_ticks(ticks) +cbar.set_ticklabels(ticklabels) +cbar.set_label(r'\textbf{Arrival Time MAE}') +axs[0].set_xticks([]) +axs[0].set_yticks([]) +axs[1].yaxis.tick_right() +axs[1].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False) +axs[1].set_yticks(np.arange(11)) +axs[1].set_xticks(np.arange(6)) +axs[1].set_yticklabels([r'\textbf{%d}' % (i+1) for i in range(11)]) +axs[1].set_xticklabels([r'\textbf{%d}' % (i+1) for i in range(6)]) +plt.savefig(os.getcwd()+ "/virtual_waveforms" + fig_dir+ + "/ArrivalTimeMatErr_v2.png",bbox_inches='tight',dpi=400) + +true_max_amplitudes_full = np.array(true_max_amplitudes_full).flatten() +sens_max_amplitudes_full = np.array(sens_max_amplitudes_full).flatten() +wang_max_amplitudes_full = np.array(wang_max_amplitudes_full).flatten() +true_arrival_times_full = np.array(true_arrival_times_full).flatten() +sens_arrival_times_full = np.array(sens_arrival_times_full).flatten() +wang_arrival_times_full = np.array(wang_arrival_times_full).flatten() + +print("Senseiver Max Amp MAE for all sims: ",np.mean(np.abs(true_max_amplitudes_full-sens_max_amplitudes_full))) +print("Wang Max Amp MAE for all sims: ",np.mean(np.abs(true_max_amplitudes_full-wang_max_amplitudes_full))) +print("Senseiver Arrival Time MAE for all sims: ",np.mean(np.abs(true_arrival_times_full-sens_arrival_times_full))) +print("Wang Arrival Time MAE for all sims: ",np.mean(np.abs(true_arrival_times_full-wang_arrival_times_full))) +logfile.close() diff --git a/tsunami/quick_plot_combined.py b/tsunami/quick_plot_combined.py index 34e3637..32fa02c 100644 --- a/tsunami/quick_plot_combined.py +++ b/tsunami/quick_plot_combined.py @@ -24,27 +24,17 @@ # load the simulation data and create a dataloader dataloader1 = senseiver_dataloader(data_config, num_workers=4) - data_config2 = data_config -data_config2['data_key'] = "unseen_agg_8_sims_0_time_ss_2_ss_unstruct_ntimes_3464_wd_new_epis_tf_289.mat" +data_config2['data_key'] = "unseen_agg_11_sims_0_time_ss_2_ss_unstruct_ntimes_3179_wd_new_epis_v7_tf_289.mat" dataloader2 = senseiver_dataloader(data_config2, num_workers=4) model_num1 = encoder_config['load_model_num'] -if model_num1 == 34: - model_num2 = 22 -else: - model_num2 = 2 +model_num2 = 125 #use the checkpoint number corresponding to desired model print(f'Loading {model_num1} ...') - model_loc1 = gb(f"lightning_logs/version_{model_num1}/checkpoints/*.ckpt")[0] -# Use the below commented code if using on HPC -# model = Senseiver.load_from_checkpoint(model_loc, -# **encoder_config, -# **decoder_config, -# **data_config) model1 = Senseiver.load_from_checkpoint(model_loc1, map_location=torch.device('cpu'), **encoder_config, **decoder_config, @@ -53,11 +43,6 @@ print(f'Loading {model_num2} ...') model_loc2 = gb(f"lightning_logs/version_{model_num2}/checkpoints/*.ckpt")[0] -# Use the below commented code if using on HPC -# model = Senseiver.load_from_checkpoint(model_loc, -# **encoder_config, -# **decoder_config, -# **data_config) model2 = Senseiver.load_from_checkpoint(model_loc2, map_location=torch.device('cpu'), **encoder_config, **decoder_config, @@ -76,10 +61,7 @@ output_im1 = torch.load(os.getcwd() + "/lightning_logs/version_" + str(model_num1) + '/tensor' + '_training' + '.pt') output_im2 = torch.load(os.getcwd() + "/lightning_logs/version_" + str(model_num2) + '/tensor' + '_training' + '.pt') -if encoder_config['load_model_num'] == 34: - path = "/Users/emcdugald/sparse_sens_tsunami/lightning_logs/combined_34_22/" -else: - path = "/Users/emcdugald/sparse_sens_tsunami/lightning_logs/combined_0_2/" +path = os.getcwd()+"/lightning_logs/combined_122_125" if unseen_flag: name = 'tensor_unseen.pt' From 3fd4be4ea487edabaf692b5444bdc0db77c6463f Mon Sep 17 00:00:00 2001 From: Edward McDugald Date: Wed, 16 Jul 2025 23:22:41 -0700 Subject: [PATCH 4/4] minor code change --- tsunami/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsunami/datasets.py b/tsunami/datasets.py index 7c5a13c..08cb8c8 100644 --- a/tsunami/datasets.py +++ b/tsunami/datasets.py @@ -39,7 +39,7 @@ def SWETsunamiWdiv(fname): def SWETsunamiForPlotting2(fname): - fpath = "/Users/emcdugald/sparse_sens_tsunami/Data/tsunami/"+fname + fpath = os.getcwd()+"/Data/tsunami/"+fname data = sio.loadmat(fpath) wave_height = data['zt'] max_wave_height = np.max(np.abs(wave_height))