-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask_distribution_updates.py
More file actions
148 lines (117 loc) · 3.63 KB
/
task_distribution_updates.py
File metadata and controls
148 lines (117 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Optional
import numpy as np
def kl_project_with_floor(z: np.ndarray, eps: float) -> np.ndarray:
    """
    KL projection of distribution z onto the eps-floored simplex:
        { q : q_i >= eps, sum_i q_i = 1 }
    Solves: min_q KL(q || z) subject to q_i >= eps.

    Iteratively clamps violating coordinates to the floor and rescales the
    remaining ("free") coordinates proportionally to z. If the floor is
    infeasible (eps * k > 1), falls back to the uniform distribution.
    """
    # Coerce to float so fractional assignments below are not truncated
    # when an integer array is passed in.
    z = np.asarray(z, dtype=np.float64)
    k = len(z)
    free = np.ones(k, dtype=bool)
    q = np.zeros_like(z)
    while True:
        mass_free = 1.0 - eps * (~free).sum()
        if mass_free < 0:
            # Floor constraints are infeasible; uniform is the sane fallback.
            return np.ones(k) / k
        if not free.any():
            # Every coordinate is clamped: the result is uniform at the floor.
            return np.full(k, 1.0 / k)
        z_free_sum = z[free].sum()
        if z_free_sum == 0:
            # No mass to distribute proportionally; spread evenly over free set.
            q[free] = mass_free / free.sum()
        else:
            q[free] = (mass_free / z_free_sum) * z[free]
        q[~free] = eps
        # Small tolerance avoids infinite loops from float round-off.
        violated = free & (q < eps - 1e-12)
        if not violated.any():
            break
        free[violated] = False
    return q / q.sum()
def mirror_ascent_kl_update(
    q: np.ndarray,
    gap: np.ndarray,
    eta: float,
    step_size: float,
    eps: Optional[float] = None,
    p0: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    One damped mirror-ascent step in KL geometry on the simplex.

    Interpolates (in log space) between the current distribution q and the
    target distribution proportional to p0 * exp(eta * gap), with step_size
    in [0, 1] as the interpolation weight; then renormalizes. When eps is
    given, the result is additionally KL-projected onto the eps-floored
    simplex.
    """
    grad = np.asarray(gap, dtype=np.float64)
    n = len(q)
    if p0 is None:
        prior = np.full(n, 1.0 / n, dtype=np.float64)
    else:
        prior = np.asarray(p0, dtype=np.float64)
        prior = prior / prior.sum()
    # step_size == 1 jumps fully onto the target logits.
    target_logits = np.log(prior) + eta * grad
    mixed_logits = step_size * target_logits + (1.0 - step_size) * np.log(q)
    # Shift by the max for numerical stability before exponentiating.
    mixed_logits -= mixed_logits.max()
    updated = np.exp(mixed_logits)
    updated /= updated.sum()
    if eps is not None and eps > 0:
        updated = kl_project_with_floor(updated, eps)
    return updated
def learning_progress_update(
    q: np.ndarray,
    learning_progress: np.ndarray,
    success_rates: np.ndarray,
    eta: float,
    step_size: float,
    eps: Optional[float] = None,
    p0: Optional[np.ndarray] = None,
    success_threshold: float = 0.9,
) -> np.ndarray:
    """
    Mirror-ascent update driven by learning progress, with solved tasks
    pinned to the probability floor.

    Tasks whose success rate exceeds success_threshold are set to the floor
    eps (0 when eps is None) so mass shifts to tasks that are still being
    learned; the result is then KL-projected onto the floored simplex.
    """
    q_new = mirror_ascent_kl_update(
        q=q,
        gap=learning_progress,
        eta=eta,
        step_size=step_size,
        eps=None,  # projection is applied once, below
        p0=p0,  # uniform by default
    )
    # BUG FIX: eps is Optional, but the original code assigned eps (None)
    # into the array and passed it to the projection, crashing. Treat a
    # missing eps as a floor of zero.
    floor = 0.0 if eps is None else eps
    solved = np.where(success_rates > success_threshold)[0]
    q_new[solved] = floor
    q_new = kl_project_with_floor(q_new, eps=floor)
    return q_new
def easy_first_curriculum_update(
    success_rates: np.ndarray,
    eps: float,
    success_threshold: float = 0.9,
) -> np.ndarray:
    """
    Curriculum that concentrates mass on the easiest unsolved task.

    The leftmost task with success rate below success_threshold receives
    1 - eps * (k - 1) probability; every other task keeps eps. Falls back
    to the uniform distribution once all tasks are solved.
    """
    n = len(success_rates)
    unsolved = np.flatnonzero(success_rates < success_threshold)
    if unsolved.size == 0:
        uniform = np.ones(n) / n
        return uniform / uniform.sum()
    # Leftmost unsolved task is treated as the current frontier.
    frontier = int(unsolved[0])
    dist = np.full(n, eps)
    dist[frontier] = 1.0 - eps * (n - 1)
    return dist / dist.sum()
def hard_first_curriculum_update(
    success_rates: np.ndarray,
    eps: float,
    success_threshold: float = 0.9,
) -> np.ndarray:
    """
    Curriculum that concentrates mass on the hardest unsolved task.

    The rightmost task with success rate below success_threshold receives
    1 - eps * (k - 1) probability; every other task keeps eps. Falls back
    to the uniform distribution once all tasks are solved.
    """
    n = len(success_rates)
    unsolved = np.flatnonzero(success_rates < success_threshold)
    if unsolved.size == 0:
        uniform = np.ones(n) / n
        return uniform / uniform.sum()
    # Rightmost unsolved task is treated as the current frontier.
    frontier = int(unsolved[-1])
    dist = np.full(n, eps)
    dist[frontier] = 1.0 - eps * (n - 1)
    return dist / dist.sum()
def exponentiated_gradient_ascent_step(
    args,
    w: np.ndarray,
    returns: np.ndarray,
    returns_ref: np.ndarray,
    previous_return_avg: np.ndarray,
    learning_rate: float = 1.0,
    eps: float = 0.1,
) -> np.ndarray:
    """
    One exponentiated-gradient step on the weight vector w.

    The gradient signal is the shortfall of returns relative to the
    reference returns (clipped at zero, so only underperformance is
    penalized). The updated weights are normalized and mixed with the
    uniform distribution (weight eps) to keep every entry strictly positive.

    NOTE(review): `args` and `previous_return_avg` are unused here; they are
    kept to preserve the call signature for existing callers.
    """
    diff = np.clip(returns_ref - returns, 0, np.inf)
    logits = learning_rate * diff
    # BUG FIX: subtract the max logit before exponentiating. After
    # normalization the result is mathematically identical, but this avoids
    # overflow to inf (and nan weights) when the shortfall is large.
    w_new = w * np.exp(logits - logits.max())
    w_new = w_new / w_new.sum()
    w_uniform = np.ones_like(w_new) / len(w_new)
    w_new = (1 - eps) * w_new + eps * w_uniform
    return w_new