diff --git a/plotting/figures/change_weight_ppo.png b/plotting/figures/change_weight_ppo.png deleted file mode 100644 index d3d4c1d..0000000 Binary files a/plotting/figures/change_weight_ppo.png and /dev/null differ diff --git a/plotting/figures/return.png b/plotting/figures/return.png deleted file mode 100644 index fd745ef..0000000 Binary files a/plotting/figures/return.png and /dev/null differ diff --git a/plotting/figures/return_change_weight_ppo.png b/plotting/figures/return_change_weight_ppo.png deleted file mode 100644 index a7dc184..0000000 Binary files a/plotting/figures/return_change_weight_ppo.png and /dev/null differ diff --git a/plotting/figures/return_change_weight_ppo_eps=1e-3.png b/plotting/figures/return_change_weight_ppo_eps=1e-3.png deleted file mode 100644 index 36aa60d..0000000 Binary files a/plotting/figures/return_change_weight_ppo_eps=1e-3.png and /dev/null differ diff --git a/plotting/figures/return_double_weight_ppo.png b/plotting/figures/return_double_weight_ppo.png deleted file mode 100644 index 9c4bbb5..0000000 Binary files a/plotting/figures/return_double_weight_ppo.png and /dev/null differ diff --git a/plotting/figures/return_dro.png b/plotting/figures/return_dro.png deleted file mode 100644 index da0bee0..0000000 Binary files a/plotting/figures/return_dro.png and /dev/null differ diff --git a/plotting/figures/return_fixedweight_ppo.png b/plotting/figures/return_fixedweight_ppo.png deleted file mode 100644 index 5fc4121..0000000 Binary files a/plotting/figures/return_fixedweight_ppo.png and /dev/null differ diff --git a/plotting/figures/return_noavg.png b/plotting/figures/return_noavg.png deleted file mode 100644 index 2685f10..0000000 Binary files a/plotting/figures/return_noavg.png and /dev/null differ diff --git a/plotting/figures/return_sweep.png b/plotting/figures/return_sweep.png deleted file mode 100755 index 2a98232..0000000 Binary files a/plotting/figures/return_sweep.png and /dev/null differ diff --git a/plotting/figures/saved_fig/bandit_200_actions_conv.png b/plotting/figures/saved_fig/bandit_200_actions_conv.png deleted file mode 100644 index 153f9b4..0000000 Binary files a/plotting/figures/saved_fig/bandit_200_actions_conv.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_change_weight_ppo_eps=1e-2.png b/plotting/figures/saved_fig/return_change_weight_ppo_eps=1e-2.png deleted file mode 100644 index e4c17d6..0000000 Binary files a/plotting/figures/saved_fig/return_change_weight_ppo_eps=1e-2.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_double_weight_ppo_100_actions.png b/plotting/figures/saved_fig/return_double_weight_ppo_100_actions.png deleted file mode 100644 index 57c671c..0000000 Binary files a/plotting/figures/saved_fig/return_double_weight_ppo_100_actions.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_double_weight_ppo_200_actions.png b/plotting/figures/saved_fig/return_double_weight_ppo_200_actions.png deleted file mode 100644 index 37592e8..0000000 Binary files a/plotting/figures/saved_fig/return_double_weight_ppo_200_actions.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_double_weight_ppo_50_actions.png b/plotting/figures/saved_fig/return_double_weight_ppo_50_actions.png deleted file mode 100644 index 84040bc..0000000 Binary files a/plotting/figures/saved_fig/return_double_weight_ppo_50_actions.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_double_weight_ppo_std=0.5.png b/plotting/figures/saved_fig/return_double_weight_ppo_std=0.5.png deleted file mode 100644 index 236a6c3..0000000 Binary files a/plotting/figures/saved_fig/return_double_weight_ppo_std=0.5.png and /dev/null differ diff --git a/plotting/figures/saved_fig/return_noavg_gym100.png b/plotting/figures/saved_fig/return_noavg_gym100.png deleted file mode 100644 index d2c298e..0000000 Binary files a/plotting/figures/saved_fig/return_noavg_gym100.png and /dev/null differ diff --git a/plotting/gridworld/success_avg.py b/plotting/gridworld/success_avg.py new file mode 100644 index 0000000..3e64caa --- /dev/null +++ b/plotting/gridworld/success_avg.py @@ -0,0 +1,102 @@ +import os + +import numpy as np +import seaborn + +import matplotlib +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt + +from rliable import library as rly +from rliable import metrics +# from rliable.plot_utils import plot_sample_efficiency_curve +from plotting.utils import plot_sample_efficiency_curve + +from plotting.utils import get_data + +def plot(x_dict, y_dict, linestyle_dict, color_dict): + results_dict = {algorithm: score for algorithm, score in y_dict.items()} + aggr_func = lambda scores: np.array([metrics.aggregate_mean([scores[..., frame]]) for frame in range(scores.shape[-1])]) + scores, cis = rly.get_interval_estimates(results_dict, aggr_func, reps=1000) + + plot_sample_efficiency_curve( + frames=x_dict, + point_estimates=scores, + interval_estimates=cis, + ax=ax, + algorithms=None, + xlabel='Timestep', + ylabel=f'Success Rate', + labelsize='large', + ticklabelsize='large', + linestyles=linestyle_dict, + colors=color_dict, + marker='', + ) + plt.legend() + plt.title(f'Average success rate over all tasks', fontsize='large') + plt.ylim(0,1.1) + + # Use scientific notation for x-axis + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + # set fontsize of scientific notation label + ax.xaxis.get_offset_text().set_fontsize('large') + + plt.tight_layout() + +if __name__ == "__main__": + + n_rows = 1 + n_cols = 1 + fig = plt.figure(figsize=(n_cols*4,n_rows*4)) + i = 1 + + x_dict, y_dict, linestyle_dict, color_dict = {}, {}, {}, {} + + ax = plt.subplot(n_rows, n_cols, i) + i+=1 + + lr = 3e-3 + ns = 256 + path_dict = { + 'DRO': f"../chtc/results1/gw4/results/ppo/dro/lr_{lr}/ns_{ns}", + 'DRO, reweight': f"../chtc/results1/gw4/results/ppo/dro_reweight/lr_{lr}/ns_{ns}", + 'Hard First': f"../chtc/results1/gw4/results/ppo/hard_first/lr_{lr}/ns_{ns}", + 'Easy First': f"../chtc/results1/gw4/results/ppo/easy_first/lr_{lr}/ns_{ns}", + 'Learning Progress': f"../chtc/results1/gw4/results/ppo/learning_progress/lr_{lr}/ns_{ns}", + 'Uniform': f"../chtc/results1/gw4/results/ppo/uniform/lr_{lr}/ns_{ns}", + } + color_palette = iter(seaborn.color_palette('colorblind', n_colors=10)) + + for key, results_dir in path_dict.items(): + x, y = get_data(results_dir, y_name=f'success_rate') + T = 100 # reduce this value if you want to truncate the data + if y is not None: + x_dict[key] = x[:T] + y_dict[key] = y[:, :T] + linestyle_dict[key] = '-' + color_dict[key] = next(color_palette) + + plot(x_dict, y_dict, linestyle_dict, color_dict) + x_dict, y_dict, linestyle_dict, color_dict = {}, {}, {}, {} + + # plt.axhline(y=1, color='k', linestyle='--', label='Optimal\nsuccess rate') + plt.axhline(y=1, color='k', linestyle='--') + + plt.legend(fontsize='large') + # # Push plots down to make room for the the legend + # fig.subplots_adjust(top=0.86) + # + # # Fetch and plot the legend from one of the subplots. + # ax = fig.axes[0] + # handles, labels = ax.get_legend_handles_labels() + # fig.legend(handles, labels, loc='upper center', fontsize='large', ncols=2) + + save_dir = f'figures' + save_name = f'success_rate_avg.png' + os.makedirs(save_dir, exist_ok=True) + plt.savefig(f'{save_dir}/{save_name}', dpi=200) + + plt.show() + + diff --git a/plotting/gridworld/success_tasks.py b/plotting/gridworld/success_tasks.py new file mode 100644 index 0000000..0de372c --- /dev/null +++ b/plotting/gridworld/success_tasks.py @@ -0,0 +1,86 @@ +import os +import numpy as np +import seaborn + +import matplotlib +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt + +from rliable import library as rly +from rliable import metrics +from utils import plot_sample_efficiency_curve +from utils import get_data + + +def plot_metric_on_ax(ax, path_dict, linestyle_dict, color_dict, metric_name, title): + x_dict, y_dict = {}, {} + + for key, results_dir in path_dict.items(): + x, y = get_data(results_dir, y_name=metric_name) + if y is not None: + x_dict[key] = x + y_dict[key] = y + + results_dict = {algorithm: score for algorithm, score in y_dict.items()} + aggr_func = lambda scores: np.array( + [metrics.aggregate_mean([scores[..., frame]]) for frame in range(scores.shape[-1])] + ) + scores, cis = rly.get_interval_estimates(results_dict, aggr_func, reps=100) + + plot_sample_efficiency_curve( + frames=x_dict, + point_estimates=scores, + interval_estimates=cis, + ax=ax, + algorithms=None, + xlabel='Timestep', + ylabel='Success rate', + labelsize='large', + ticklabelsize='large', + linestyles=linestyle_dict, + colors=color_dict, + marker='', + ) + + ax.set_title(title, fontsize='large') + ax.set_ylim(0, 1.05) + ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + ax.xaxis.get_offset_text().set_fontsize('large') + + +if __name__ == "__main__": + + fig = plt.figure(figsize=(16,4)) + + lr = 3e-3 + ns = 256 + path_dict = { + 'DRO': f"../chtc/results1/gw4/results/ppo/dro/lr_{lr}/ns_{ns}", + 'DRO, reweight': f"../chtc/results1/gw4/results/ppo/dro_reweight/lr_{lr}/ns_{ns}", + 'Hard First': f"../chtc/results1/gw4/results/ppo/hard_first/lr_{lr}/ns_{ns}", + 'Easy First': f"../chtc/results1/gw4/results/ppo/easy_first/lr_{lr}/ns_{ns}", + 'Learning Progress': f"../chtc/results1/gw4/results/ppo/learning_progress/lr_{lr}/ns_{ns}", + 'Uniform': f"../chtc/results1/gw4/results/ppo/uniform/lr_{lr}/ns_{ns}", + } + palette = seaborn.color_palette('colorblind', n_colors=10) + color_dict = dict(zip(path_dict.keys(), palette)) + linestyle_dict = {k: '-' for k in path_dict.keys()} + + for task_i in range(4): + ax = plt.subplot(1,4,task_i+1) + plot_metric_on_ax( + ax, + path_dict, + linestyle_dict, + color_dict, + metric_name=f'success_rate_{task_i}', + title=f'Task {task_i}' + ) + + if task_i == 0: + ax.legend(fontsize='large') + + plt.tight_layout() + os.makedirs("figures", exist_ok=True) + plt.savefig("figures/per_task_success.png", dpi=200) + plt.show() \ No newline at end of file diff --git a/plotting/gridworld/task_probs.py b/plotting/gridworld/task_probs.py new file mode 100644 index 0000000..77509d7 --- /dev/null +++ b/plotting/gridworld/task_probs.py @@ -0,0 +1,86 @@ +import os +import numpy as np +import seaborn + +import matplotlib +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt + +from rliable import library as rly +from rliable import metrics +from utils import plot_sample_efficiency_curve +from utils import get_data + + +def plot_metric_on_ax(ax, path_dict, linestyle_dict, color_dict, metric_name, title): + x_dict, y_dict = {}, {} + + for key, results_dir in path_dict.items(): + x, y = get_data(results_dir, y_name=metric_name) + if y is not None: + x_dict[key] = x + y_dict[key] = y + + results_dict = {algorithm: score for algorithm, score in y_dict.items()} + aggr_func = lambda scores: np.array( + [metrics.aggregate_mean([scores[..., frame]]) for frame in range(scores.shape[-1])] + ) + scores, cis = rly.get_interval_estimates(results_dict, aggr_func, reps=100) + + plot_sample_efficiency_curve( + frames=x_dict, + point_estimates=scores, + interval_estimates=cis, + ax=ax, + algorithms=None, + xlabel='Timestep', + ylabel='Probability', + labelsize='large', + ticklabelsize='large', + linestyles=linestyle_dict, + colors=color_dict, + marker='', + ) + + ax.set_title(title, fontsize='large') + ax.set_ylim(0, 1.05) + ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + ax.xaxis.get_offset_text().set_fontsize('large') + + +if __name__ == "__main__": + + fig = plt.figure(figsize=(16,4)) + + lr = 3e-3 + ns = 256 + path_dict = { + 'DRO': f"../chtc/results1/gw4/results/ppo/dro/lr_{lr}/ns_{ns}", + 'DRO, reweight': f"../chtc/results1/gw4/results/ppo/dro_reweight/lr_{lr}/ns_{ns}", + 'Hard First': f"../chtc/results1/gw4/results/ppo/hard_first/lr_{lr}/ns_{ns}", + 'Easy First': f"../chtc/results1/gw4/results/ppo/easy_first/lr_{lr}/ns_{ns}", + 'Learning Progress': f"../chtc/results1/gw4/results/ppo/learning_progress/lr_{lr}/ns_{ns}", + 'Uniform': f"../chtc/results1/gw4/results/ppo/uniform/lr_{lr}/ns_{ns}", + } + palette = seaborn.color_palette('colorblind', n_colors=10) + color_dict = dict(zip(path_dict.keys(), palette)) + linestyle_dict = {k: '-' for k in path_dict.keys()} + + for task_i in range(4): + ax = plt.subplot(1,4,task_i+1) + plot_metric_on_ax( + ax, + path_dict, + linestyle_dict, + color_dict, + metric_name=f'task_probs_{task_i}', + title=f'Task {task_i} sampling prob' + ) + + if task_i == 0: + ax.legend(fontsize='large') + + plt.tight_layout() + os.makedirs("figures", exist_ok=True) + plt.savefig("figures/task_probs.png", dpi=200) + plt.show() \ No newline at end of file diff --git a/plotting/bandit/figures/interesting_result/return_avg_ablate.png b/plotting/old/bandit/figures/interesting_result/return_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/interesting_result/return_avg_ablate.png rename to plotting/old/bandit/figures/interesting_result/return_avg_ablate.png diff --git a/plotting/bandit/figures/interesting_result/success_rate_and_weights.png b/plotting/old/bandit/figures/interesting_result/success_rate_and_weights.png similarity index 100% rename from plotting/bandit/figures/interesting_result/success_rate_and_weights.png rename to plotting/old/bandit/figures/interesting_result/success_rate_and_weights.png diff --git a/plotting/bandit/figures/interesting_result/success_rate_avg_ablate.png b/plotting/old/bandit/figures/interesting_result/success_rate_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/interesting_result/success_rate_avg_ablate.png rename to plotting/old/bandit/figures/interesting_result/success_rate_avg_ablate.png diff --git a/plotting/bandit/figures/momentum/return_avg_ablate.png b/plotting/old/bandit/figures/momentum/return_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/momentum/return_avg_ablate.png rename to plotting/old/bandit/figures/momentum/return_avg_ablate.png diff --git a/plotting/bandit/figures/momentum/success_rate_and_weights.png b/plotting/old/bandit/figures/momentum/success_rate_and_weights.png similarity index 100% rename from plotting/bandit/figures/momentum/success_rate_and_weights.png rename to plotting/old/bandit/figures/momentum/success_rate_and_weights.png diff --git a/plotting/bandit/figures/momentum/success_rate_avg_ablate.png b/plotting/old/bandit/figures/momentum/success_rate_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/momentum/success_rate_avg_ablate.png rename to plotting/old/bandit/figures/momentum/success_rate_avg_ablate.png diff --git a/plotting/bandit/figures/no_momentum/return_avg_ablate.png b/plotting/old/bandit/figures/no_momentum/return_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/no_momentum/return_avg_ablate.png rename to plotting/old/bandit/figures/no_momentum/return_avg_ablate.png diff --git a/plotting/bandit/figures/no_momentum/success_rate_and_weights.png b/plotting/old/bandit/figures/no_momentum/success_rate_and_weights.png similarity index 100% rename from plotting/bandit/figures/no_momentum/success_rate_and_weights.png rename to plotting/old/bandit/figures/no_momentum/success_rate_and_weights.png diff --git a/plotting/bandit/figures/no_momentum/success_rate_avg_ablate.png b/plotting/old/bandit/figures/no_momentum/success_rate_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/no_momentum/success_rate_avg_ablate.png rename to plotting/old/bandit/figures/no_momentum/success_rate_avg_ablate.png diff --git a/plotting/bandit/figures/return_avg_ablate.png b/plotting/old/bandit/figures/return_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/return_avg_ablate.png rename to plotting/old/bandit/figures/return_avg_ablate.png diff --git a/plotting/bandit/figures/single_task_return.png b/plotting/old/bandit/figures/single_task_return.png similarity index 100% rename from plotting/bandit/figures/single_task_return.png rename to plotting/old/bandit/figures/single_task_return.png diff --git a/plotting/bandit/figures/single_task_success.png b/plotting/old/bandit/figures/single_task_success.png similarity index 100% rename from plotting/bandit/figures/single_task_success.png rename to plotting/old/bandit/figures/single_task_success.png diff --git a/plotting/bandit/figures/success_rate_and_weights.png b/plotting/old/bandit/figures/success_rate_and_weights.png similarity index 100% rename from plotting/bandit/figures/success_rate_and_weights.png rename to plotting/old/bandit/figures/success_rate_and_weights.png diff --git a/plotting/bandit/figures/success_rate_avg_ablate.png b/plotting/old/bandit/figures/success_rate_avg_ablate.png similarity index 100% rename from plotting/bandit/figures/success_rate_avg_ablate.png rename to plotting/old/bandit/figures/success_rate_avg_ablate.png diff --git a/plotting/bandit/plot_return_avg_ablate.py b/plotting/old/bandit/plot_return_avg_ablate.py similarity index 100% rename from plotting/bandit/plot_return_avg_ablate.py rename to plotting/old/bandit/plot_return_avg_ablate.py diff --git a/plotting/bandit/plot_single_returns.py b/plotting/old/bandit/plot_single_returns.py similarity index 100% rename from plotting/bandit/plot_single_returns.py rename to plotting/old/bandit/plot_single_returns.py diff --git a/plotting/bandit/plot_single_success.py b/plotting/old/bandit/plot_single_success.py similarity index 100% rename from plotting/bandit/plot_single_success.py rename to plotting/old/bandit/plot_single_success.py diff --git a/plotting/bandit/plot_success_avg_ablate.py b/plotting/old/bandit/plot_success_avg_ablate.py similarity index 100% rename from plotting/bandit/plot_success_avg_ablate.py rename to plotting/old/bandit/plot_success_avg_ablate.py diff --git a/plotting/bandit/plot_task_success.py b/plotting/old/bandit/plot_task_success.py similarity index 100% rename from plotting/bandit/plot_task_success.py rename to plotting/old/bandit/plot_task_success.py diff --git a/plotting/bandit/temp.py b/plotting/old/bandit/temp.py similarity index 100% rename from plotting/bandit/temp.py rename to plotting/old/bandit/temp.py diff --git a/plotting/bandit/utils.py b/plotting/old/bandit/utils.py similarity index 100% rename from plotting/bandit/utils.py rename to plotting/old/bandit/utils.py diff --git a/plotting/goal2d/figures/interesting_result/return_avg_ablate.png b/plotting/old/goal2d/figures/interesting_result/return_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/interesting_result/return_avg_ablate.png rename to plotting/old/goal2d/figures/interesting_result/return_avg_ablate.png diff --git a/plotting/goal2d/figures/interesting_result/success_rate_and_weights.png b/plotting/old/goal2d/figures/interesting_result/success_rate_and_weights.png similarity index 100% rename from plotting/goal2d/figures/interesting_result/success_rate_and_weights.png rename to plotting/old/goal2d/figures/interesting_result/success_rate_and_weights.png diff --git a/plotting/goal2d/figures/interesting_result/success_rate_avg_ablate.png b/plotting/old/goal2d/figures/interesting_result/success_rate_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/interesting_result/success_rate_avg_ablate.png rename to plotting/old/goal2d/figures/interesting_result/success_rate_avg_ablate.png diff --git a/plotting/goal2d/figures/momentum/return_avg_ablate.png b/plotting/old/goal2d/figures/momentum/return_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/momentum/return_avg_ablate.png rename to plotting/old/goal2d/figures/momentum/return_avg_ablate.png diff --git a/plotting/goal2d/figures/momentum/success_rate_and_weights.png b/plotting/old/goal2d/figures/momentum/success_rate_and_weights.png similarity index 100% rename from plotting/goal2d/figures/momentum/success_rate_and_weights.png rename to plotting/old/goal2d/figures/momentum/success_rate_and_weights.png diff --git a/plotting/goal2d/figures/momentum/success_rate_avg_ablate.png b/plotting/old/goal2d/figures/momentum/success_rate_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/momentum/success_rate_avg_ablate.png rename to plotting/old/goal2d/figures/momentum/success_rate_avg_ablate.png diff --git a/plotting/goal2d/figures/no_momentum/return_avg_ablate.png b/plotting/old/goal2d/figures/no_momentum/return_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/no_momentum/return_avg_ablate.png rename to plotting/old/goal2d/figures/no_momentum/return_avg_ablate.png diff --git a/plotting/goal2d/figures/no_momentum/success_rate_and_weights.png b/plotting/old/goal2d/figures/no_momentum/success_rate_and_weights.png similarity index 100% rename from plotting/goal2d/figures/no_momentum/success_rate_and_weights.png rename to plotting/old/goal2d/figures/no_momentum/success_rate_and_weights.png diff --git a/plotting/goal2d/figures/no_momentum/success_rate_avg_ablate.png b/plotting/old/goal2d/figures/no_momentum/success_rate_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/no_momentum/success_rate_avg_ablate.png rename to plotting/old/goal2d/figures/no_momentum/success_rate_avg_ablate.png diff --git a/plotting/goal2d/figures/return_avg_ablate.png b/plotting/old/goal2d/figures/return_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/return_avg_ablate.png rename to plotting/old/goal2d/figures/return_avg_ablate.png diff --git a/plotting/goal2d/figures/single_task_return.png b/plotting/old/goal2d/figures/single_task_return.png similarity index 100% rename from plotting/goal2d/figures/single_task_return.png rename to plotting/old/goal2d/figures/single_task_return.png diff --git a/plotting/goal2d/figures/single_task_success.png b/plotting/old/goal2d/figures/single_task_success.png similarity index 100% rename from plotting/goal2d/figures/single_task_success.png rename to plotting/old/goal2d/figures/single_task_success.png diff --git a/plotting/goal2d/figures/success_rate_and_weights.png b/plotting/old/goal2d/figures/success_rate_and_weights.png similarity index 100% rename from plotting/goal2d/figures/success_rate_and_weights.png rename to plotting/old/goal2d/figures/success_rate_and_weights.png diff --git a/plotting/goal2d/figures/success_rate_avg_ablate.png b/plotting/old/goal2d/figures/success_rate_avg_ablate.png similarity index 100% rename from plotting/goal2d/figures/success_rate_avg_ablate.png rename to plotting/old/goal2d/figures/success_rate_avg_ablate.png diff --git a/plotting/goal2d/plot_return_avg_ablate.py b/plotting/old/goal2d/plot_return_avg_ablate.py similarity index 100% rename from plotting/goal2d/plot_return_avg_ablate.py rename to plotting/old/goal2d/plot_return_avg_ablate.py diff --git a/plotting/goal2d/plot_single_returns.py b/plotting/old/goal2d/plot_single_returns.py similarity index 100% rename from plotting/goal2d/plot_single_returns.py rename to plotting/old/goal2d/plot_single_returns.py diff --git a/plotting/goal2d/plot_single_success.py b/plotting/old/goal2d/plot_single_success.py similarity index 100% rename from plotting/goal2d/plot_single_success.py rename to plotting/old/goal2d/plot_single_success.py diff --git a/plotting/goal2d/plot_success_avg_ablate.py b/plotting/old/goal2d/plot_success_avg_ablate.py similarity index 100% rename from plotting/goal2d/plot_success_avg_ablate.py rename to plotting/old/goal2d/plot_success_avg_ablate.py diff --git a/plotting/goal2d/plot_task_success.py b/plotting/old/goal2d/plot_task_success.py similarity index 100% rename from plotting/goal2d/plot_task_success.py rename to plotting/old/goal2d/plot_task_success.py diff --git a/plotting/goal2d/temp.py b/plotting/old/goal2d/temp.py similarity index 100% rename from plotting/goal2d/temp.py rename to plotting/old/goal2d/temp.py diff --git a/plotting/goal2d/utils.py b/plotting/old/goal2d/utils.py similarity index 100% rename from plotting/goal2d/utils.py rename to plotting/old/goal2d/utils.py diff --git a/plotting/gridworld/figures/chessboard.png b/plotting/old/gridworld_old/figures/chessboard.png similarity index 100% rename from plotting/gridworld/figures/chessboard.png rename to plotting/old/gridworld_old/figures/chessboard.png diff --git a/plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_and_weights.png b/plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_and_weights.png similarity index 100% rename from plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_and_weights.png rename to plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_and_weights.png diff --git a/plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_avg_ablate.png b/plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_avg_ablate.png similarity index 100% rename from plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_avg_ablate.png rename to plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=0.2/success_rate_avg_ablate.png diff --git a/plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_and_weights.png b/plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_and_weights.png similarity index 100% rename from plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_and_weights.png rename to plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_and_weights.png diff --git a/plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_avg_ablate.png b/plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_avg_ablate.png similarity index 100% rename from plotting/gridworld/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_avg_ablate.png rename to plotting/old/gridworld_old/figures/dro_num_steps=128/dro_learning_rate=1.0/success_rate_avg_ablate.png diff --git a/plotting/gridworld/figures/success_rate_and_weights.png b/plotting/old/gridworld_old/figures/success_rate_and_weights.png similarity index 100% rename from plotting/gridworld/figures/success_rate_and_weights.png rename to plotting/old/gridworld_old/figures/success_rate_and_weights.png diff --git a/plotting/gridworld/figures/success_rate_avg_ablate.png b/plotting/old/gridworld_old/figures/success_rate_avg_ablate.png similarity index 100% rename from plotting/gridworld/figures/success_rate_avg_ablate.png rename to plotting/old/gridworld_old/figures/success_rate_avg_ablate.png diff --git a/plotting/gridworld/figures/success_rate_avg_ablate_for_diff_probs_init.png b/plotting/old/gridworld_old/figures/success_rate_avg_ablate_for_diff_probs_init.png similarity index 100% rename from plotting/gridworld/figures/success_rate_avg_ablate_for_diff_probs_init.png rename to plotting/old/gridworld_old/figures/success_rate_avg_ablate_for_diff_probs_init.png diff --git a/plotting/gridworld/plot_return_avg_ablate.py b/plotting/old/gridworld_old/plot_return_avg_ablate.py similarity index 100% rename from plotting/gridworld/plot_return_avg_ablate.py rename to plotting/old/gridworld_old/plot_return_avg_ablate.py diff --git a/plotting/gridworld/plot_single_success.py b/plotting/old/gridworld_old/plot_single_success.py similarity index 100% rename from plotting/gridworld/plot_single_success.py rename to plotting/old/gridworld_old/plot_single_success.py diff --git a/plotting/gridworld/plot_success_avg_ablate.py b/plotting/old/gridworld_old/plot_success_avg_ablate.py similarity index 100% rename from plotting/gridworld/plot_success_avg_ablate.py rename to plotting/old/gridworld_old/plot_success_avg_ablate.py diff --git a/plotting/gridworld/plot_task_success.py b/plotting/old/gridworld_old/plot_task_success.py similarity index 100% rename from plotting/gridworld/plot_task_success.py rename to plotting/old/gridworld_old/plot_task_success.py diff --git a/plotting/gridworld/temp.py b/plotting/old/gridworld_old/temp.py similarity index 100% rename from plotting/gridworld/temp.py rename to plotting/old/gridworld_old/temp.py diff --git a/plotting/gridworld/utils.py b/plotting/old/gridworld_old/utils.py similarity index 100% rename from plotting/gridworld/utils.py rename to plotting/old/gridworld_old/utils.py diff --git a/plotting/plot.py b/plotting/old/plot.py similarity index 100% rename from plotting/plot.py rename to plotting/old/plot.py diff --git a/plotting/plot_change_weight.py b/plotting/old/plot_change_weight.py similarity index 100% rename from plotting/plot_change_weight.py rename to plotting/old/plot_change_weight.py diff --git a/plotting/plot_ddpg.py b/plotting/old/plot_ddpg.py similarity index 100% rename from plotting/plot_ddpg.py rename to plotting/old/plot_ddpg.py diff --git a/plotting/plot_dro.py b/plotting/old/plot_dro.py similarity index 100% rename from plotting/plot_dro.py rename to plotting/old/plot_dro.py diff --git a/plotting/plot_fixedweight.py b/plotting/old/plot_fixedweight.py similarity index 100% rename from plotting/plot_fixedweight.py rename to plotting/old/plot_fixedweight.py diff --git a/plotting/plot_fw_ga.py b/plotting/old/plot_fw_ga.py similarity index 100% rename from plotting/plot_fw_ga.py rename to plotting/old/plot_fw_ga.py diff --git a/plotting/plot_no_avg.py b/plotting/old/plot_no_avg.py similarity index 100% rename from plotting/plot_no_avg.py rename to plotting/old/plot_no_avg.py diff --git a/plotting/plot_sweep.py b/plotting/old/plot_sweep.py similarity index 100% rename from plotting/plot_sweep.py rename to plotting/old/plot_sweep.py diff --git a/plotting/plot_updated.py b/plotting/old/plot_updated.py similarity index 100% rename from plotting/plot_updated.py rename to plotting/old/plot_updated.py diff --git a/plotting/pointmaze/figures/return_avg_ablate.png b/plotting/old/pointmaze/figures/return_avg_ablate.png similarity index 100% rename from plotting/pointmaze/figures/return_avg_ablate.png rename to plotting/old/pointmaze/figures/return_avg_ablate.png diff --git a/plotting/pointmaze/figures/single_task_results.png b/plotting/old/pointmaze/figures/single_task_results.png similarity index 100% rename from plotting/pointmaze/figures/single_task_results.png rename to plotting/old/pointmaze/figures/single_task_results.png diff --git a/plotting/pointmaze/figures/single_task_return.png b/plotting/old/pointmaze/figures/single_task_return.png similarity index 100% rename from plotting/pointmaze/figures/single_task_return.png rename to plotting/old/pointmaze/figures/single_task_return.png diff --git a/plotting/pointmaze/figures/single_task_success.png b/plotting/old/pointmaze/figures/single_task_success.png similarity index 100% rename from plotting/pointmaze/figures/single_task_success.png rename to plotting/old/pointmaze/figures/single_task_success.png diff --git a/plotting/pointmaze/figures/success_rate_and_weights.png b/plotting/old/pointmaze/figures/success_rate_and_weights.png similarity index 100% rename from plotting/pointmaze/figures/success_rate_and_weights.png rename to plotting/old/pointmaze/figures/success_rate_and_weights.png diff --git a/plotting/pointmaze/figures/success_rate_avg_ablate.png b/plotting/old/pointmaze/figures/success_rate_avg_ablate.png similarity index 100% rename from plotting/pointmaze/figures/success_rate_avg_ablate.png rename to plotting/old/pointmaze/figures/success_rate_avg_ablate.png diff --git a/plotting/pointmaze/plot_return_avg_ablate.py b/plotting/old/pointmaze/plot_return_avg_ablate.py similarity index 100% rename from plotting/pointmaze/plot_return_avg_ablate.py rename to plotting/old/pointmaze/plot_return_avg_ablate.py diff --git a/plotting/pointmaze/plot_single_returns.py b/plotting/old/pointmaze/plot_single_returns.py similarity index 100% rename from plotting/pointmaze/plot_single_returns.py rename to plotting/old/pointmaze/plot_single_returns.py diff --git a/plotting/pointmaze/plot_single_success.py b/plotting/old/pointmaze/plot_single_success.py similarity index 100% rename from plotting/pointmaze/plot_single_success.py rename to plotting/old/pointmaze/plot_single_success.py diff --git a/plotting/pointmaze/plot_success_avg_ablate.py b/plotting/old/pointmaze/plot_success_avg_ablate.py similarity index 100% rename from plotting/pointmaze/plot_success_avg_ablate.py rename to plotting/old/pointmaze/plot_success_avg_ablate.py diff --git a/plotting/pointmaze/plot_task_success.py b/plotting/old/pointmaze/plot_task_success.py similarity index 100% rename from plotting/pointmaze/plot_task_success.py rename to plotting/old/pointmaze/plot_task_success.py diff --git a/plotting/pointmaze/plot_update_count.py b/plotting/old/pointmaze/plot_update_count.py similarity index 100% rename from plotting/pointmaze/plot_update_count.py rename to plotting/old/pointmaze/plot_update_count.py diff --git a/plotting/pointmaze/utils.py b/plotting/old/pointmaze/utils.py similarity index 100% rename from plotting/pointmaze/utils.py rename to plotting/old/pointmaze/utils.py diff --git a/plotting/utils.py b/plotting/utils.py index 996ed78..fcaf4e8 100644 --- a/plotting/utils.py +++ b/plotting/utils.py @@ -1,126 +1,19 @@ -# -# def plot(save_dict, name, m=100000, success_threshold=None, return_cutoff=-np.inf): -# i = 0 -# -# # palette = seaborn.color_palette() -# print(os.getcwd()) -# -# for agent, info in save_dict.items(): -# paths = info['paths'] -# x_scale = info['x_scale'] -# max_t = info['max_t'] -# avgs = [] -# for path in paths: -# u, t, avg = load_data(path, name=name, success_threshold=success_threshold) -# if avg is not None: -# if max_t: -# cutoff = np.where(t <= max_t/x_scale)[0] -# avg = avg[cutoff] -# t = t[cutoff] -# -# elif m: -# avg = avg[:m] -# avgs.append(avg) -# t_good = t -# -# if len(avgs) == 0: -# continue -# elif len(avgs) == 1: -# avg_of_avgs = avg -# q05 = np.zeros_like(avg) -# q95 = np.zeros_like(avg) -# -# else: -# -# min_l = np.inf -# for a in avgs: -# l = len(a) -# if l < min_l: -# min_l = l -# -# if min_l < np.inf: -# for i in range(len(avgs)): -# avgs[i] = avgs[i][:min_l] -# -# avg_of_avgs = np.mean(avgs, axis=0) -# -# # if avg_of_avgs.mean() > 0: continue -# # print(np.median(avg_of_avgs)) -# # if np.median(avg_of_avgs) > 0: continue -# -# std = np.std(avgs, axis=0) -# N = len(avgs) -# ci = 1.96 * std / np.sqrt(N) * 1.96 -# q05 = avg_of_avgs - ci -# q95 = avg_of_avgs + ci -# -# -# # if avg_of_avgs[-10:].mean() < 4900 or N < 10: continue -# -# style_kwargs = get_line_styles(agent) -# style_kwargs['linewidth'] = 2 -# -# # style_kwargs['linewidth'] = 1.5 -# -# style_kwargs['color'] = None -# # if 'PROPS' in agent: -# # style_kwargs['linestyle'] = '-' -# # style_kwargs['linewidth'] = 3 -# # # style_kwargs['color'] = 'k' -# # -# # -# # elif 'ppo_buffer' in agent or 'PPO-Buffer' in agent or 'b=' in agent or 'Buffer' in agent: -# # style_kwargs['linestyle'] = '--' -# # elif 'ppo,' in agent or 'PPO,' in agent or 'PPO with' in agent or 'PPO' == agent: -# # style_kwargs['linestyle'] = ':' -# # elif 'Priv' in agent: -# # style_kwargs['linestyle'] = '-.' -# # -# # elif '0.0001' in agent: -# # style_kwargs['linestyle'] = '--' -# -# # print(agent, N, avg_of_avgs[-1], q05[-1], q95[-1]) -# -# try: -# times = info['times'] -# x = times -# except: -# x = t_good * x_scale -# if t is None: -# x = np.arange(len(avg_of_avgs)) -# if m: -# x = x[:m] -# avg_of_avgs = avg_of_avgs[:m] -# q05 = q05[:m] -# q95 = q95[:m] -# plt.plot(x, avg_of_avgs, label=agent, **style_kwargs) -# if style_kwargs['linestyle'] == 'None': -# plt.fill_between(x, q05, q95, alpha=0) -# else: -# plt.fill_between(x, q05, q95, alpha=0.2) -# # plt.fill_between(x, q05, q95, alpha=0.2, color=style_kwargs['color']) -# -# i += 1 -# # return fig - - import os import warnings import numpy as np +from matplotlib import pyplot as plt +import seaborn as sns +from rliable.plot_utils import _annotate_and_decorate_axis + -def get_data(results_dir, x_name='timestep', y_name='returns', lable_name='env_ids', id_name='task_ids', filename='evaluations.npz'): +def get_data(results_dir, x_name='timestep', y_name='return', filename='evaluations.npz'): - print (results_dir) paths = [] try: for subdir in os.listdir(results_dir): - if 'run_' in subdir: - cur_path = f'{results_dir}/{subdir}/{filename}' - if os.path.isfile(cur_path): - paths.append(cur_path) - else: - print (f'No file at {cur_path}!') + if 'run_' in subdir and os.path.exists(f'{results_dir}/{subdir}/{filename}'): + paths.append(f'{results_dir}/{subdir}/{filename}') except Exception as e: print(e) @@ -134,16 +27,94 @@ def get_data(results_dir, x_name='timestep', y_name='returns', lable_name='env_i x = None length = None - ids = None for path in paths: with np.load(path) as data_file: + # for d in data_file: + # print(d) if x is None: x = data_file[x_name] y = data_file[y_name] - y_list.append(y) - z = data_file[lable_name] - ids = data_file[id_name] + if length is None: + length = len(y) + if len(y) == length: + y_list.append(y) + + return x, np.array(y_list) + + + +def plot_sample_efficiency_curve(frames, + point_estimates, + interval_estimates, + algorithms=None, + colors=None, + color_palette='colorblind', + linestyles=None, + figsize=(7, 5), + xlabel=r'Number of Frames (in millions)', + ylabel='Aggregate Human Normalized Score', + ax=None, + labelsize='xx-large', + ticklabelsize='xx-large', + **kwargs): + """Plots an aggregate metric with CIs as a function of environment frames. + + Args: + frames: Array or list containing environment frames to mark on the x-axis. + point_estimates: Dictionary mapping algorithm to a list or array of point + estimates of the metric corresponding to the values in `frames`. + interval_estimates: Dictionary mapping algorithms to interval estimates + corresponding to the `point_estimates`. Typically, consists of stratified + bootstrap CIs. + algorithms: List of methods used for plotting. If None, defaults to all the + keys in `point_estimates`. + colors: Dictionary that maps each algorithm to a color. If None, then this + mapping is created based on `color_palette`. + color_palette: `seaborn.color_palette` object for mapping each method to a + color. + figsize: Size of the figure passed to `matplotlib.subplots`. Only used when + `ax` is None. + xlabel: Label for the x-axis. + ylabel: Label for the y-axis. + ax: `matplotlib.axes` object. + labelsize: Font size of the x-axis label. + ticklabelsize: Font size of the ticks. + **kwargs: Arbitrary keyword arguments. + + Returns: + `axes.Axes` object containing the plot. + """ + if ax is None: + _, ax = plt.subplots(figsize=figsize) + if algorithms is None: + algorithms = list(point_estimates.keys()) + if colors is None: + color_palette = sns.color_palette(color_palette, n_colors=len(algorithms)) + colors = dict(zip(algorithms, color_palette)) + if linestyles is None: + linestyles = dict(zip(algorithms, '-')) - return x, np.array(y_list), z, ids + for algorithm in algorithms: + metric_values = point_estimates[algorithm] + lower, upper = interval_estimates[algorithm] + ax.plot( + frames[algorithm], + metric_values, + color=colors[algorithm], + linestyle=linestyles[algorithm], + marker=kwargs.get('marker', 'o'), + linewidth=kwargs.get('linewidth', 2), + label=algorithm) + ax.fill_between( + frames[algorithm], y1=lower, y2=upper, color=colors[algorithm], alpha=0.2) + kwargs.pop('marker', '0') + kwargs.pop('linewidth', '2') + return _annotate_and_decorate_axis( + ax, + xlabel=xlabel, + ylabel=ylabel, + labelsize=labelsize, + ticklabelsize=ticklabelsize, + **kwargs)