-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrender_partial.py
More file actions
114 lines (93 loc) · 4.01 KB
/
Copy pathrender_partial.py
File metadata and controls
114 lines (93 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation, FFMpegWriter
from DQN import QNetwork
from call_cart_pole_demo import CartPoleEnv
MODEL_PATH = "dqn_model_partial_with_x_theta_theta_dot.pth"
OUTPUT_NAME = "cartpole_agent_partial"  # saves as .mp4 or .gif
TITLE = "DQN Swing-Up Agent — Simulink CartPole"
MAX_STEPS = 500  # hard cap on rollout length if the env never reports done

# Set OBS_INDICES to match whatever the model was trained with:
#   [0, 1, 2, 3] = full state
#   [0, 2, 3]    = x, theta, theta_dot
#   [2, 3]       = theta, theta_dot only
OBS_INDICES = [0, 2, 3]
OBS_DIM = len(OBS_INDICES)

device = torch.device("cpu")
policy_net = QNetwork(state_dim=OBS_DIM).to(device)
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
policy_net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
policy_net.eval()

# NOTE(review): TITLE advertises a swing-up agent but the env is created with
# swing_up=False -- confirm which task this checkpoint was trained on.
env = CartPoleEnv(swing_up=False)
state = env.reset()

# Greedy rollout (argmax over Q-values). states[k] is the full state *before*
# action k and rewards[k] is the reward returned by that action, so the two
# lists have equal length and can be indexed in lockstep when animating.
states, rewards = [], []
for _ in range(MAX_STEPS):
    states.append(state.copy())  # always store full state for animation
    # Filter to what the agent sees; asarray makes fancy indexing work even if
    # the env hands back a plain sequence instead of an ndarray.
    obs = np.asarray(state)[OBS_INDICES]
    with torch.no_grad():
        # Build the input on the same device as the network so changing
        # `device` above (e.g. to CUDA) does not break the forward pass.
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        action_idx = policy_net(obs_t).argmax().item()
    state, reward, done = env.step(action_idx)
    rewards.append(reward)
    if done:
        break
print(f"Episode length: {len(states)} steps, Total reward: {sum(rewards):.1f}")
CART_W, CART_H = 0.4, 0.2
POLE_LEN = 1.0
TRACK_LIM = 2.5

fig, (ax, ax2) = plt.subplots(2, 1, figsize=(8, 7),
                              gridspec_kw={'height_ratios': [3, 1]})
fig.patch.set_facecolor('#1a1a2e')
for a in [ax, ax2]:
    a.set_facecolor('#16213e')

# --- Cart-pole view ----------------------------------------------------------
# The pole tip is drawn at y = POLE_LEN * cos(theta), which ranges over
# [-POLE_LEN, POLE_LEN]; size the view so a hanging pole is visible as well as
# an upright one. The extra headroom above leaves space for the text labels.
ax.set_xlim(-TRACK_LIM - 0.5, TRACK_LIM + 0.5)
ax.set_ylim(-POLE_LEN - 0.5, POLE_LEN + 0.8)
ax.set_aspect('equal')
ax.axhline(y=0, color='#e94560', linewidth=2, alpha=0.5)  # track line
# Dashed markers at the track limits.
ax.axvline(x=-TRACK_LIM, color='#e94560', linewidth=1, linestyle='--', alpha=0.4)
ax.axvline(x=TRACK_LIM, color='#e94560', linewidth=1, linestyle='--', alpha=0.4)
ax.set_title(TITLE, color='white', fontsize=13)
ax.tick_params(colors='white')

# Artists repositioned every frame by animate(); initial geometry is a placeholder.
cart = patches.Rectangle((0, 0), CART_W, CART_H,
                         facecolor='#0f3460', edgecolor='#e94560', linewidth=2)
ax.add_patch(cart)
pole_line, = ax.plot([], [], color='#f5a623', linewidth=5, solid_capstyle='round')
pivot_dot, = ax.plot([], [], 'o', color='white', markersize=6, zorder=5)
tip_dot, = ax.plot([], [], 'o', color='#f5a623', markersize=8, zorder=5)
step_text = ax.text(0.02, 0.95, '', transform=ax.transAxes,
                    color='white', fontsize=10, va='top')
angle_text = ax.text(0.98, 0.95, '', transform=ax.transAxes,
                     color='#f5a623', fontsize=10, va='top', ha='right')

# --- Reward trace ------------------------------------------------------------
# max(..., 1) avoids a zero-width x-axis if no steps were recorded.
ax2.set_xlim(0, max(len(states), 1))
# Keep the default +/-1.5 window, but widen it when a recorded reward would
# otherwise fall outside the plot and be clipped.
ax2.set_ylim(min(-1.5, min(rewards, default=0.0) - 0.1),
             max(1.5, max(rewards, default=0.0) + 0.1))
ax2.set_ylabel('Reward', color='white', fontsize=9)
ax2.set_xlabel('Step', color='white', fontsize=9)
ax2.tick_params(colors='white')
ax2.axhline(y=0, color='gray', linewidth=0.5)
reward_line, = ax2.plot([], [], color='#f5a623', linewidth=1.5)
reward_history = []  # module-level list, retained so existing references keep working
def animate(i):
    """Render frame ``i`` of the recorded episode.

    Places the cart and pole from the full state ``states[i]``, updates the
    step/angle labels, and draws the reward trace for steps ``0..i``.

    The reward trace is sliced straight from ``rewards`` instead of being
    appended to a running list. ``FuncAnimation`` may call this function more
    than once for the same frame index (initial blit draw, the save pass, and
    every loop of the live window). An accumulating list would drift out of
    sync with ``i`` and keep growing on each replay.

    Args:
        i: Frame index, ``0 <= i < len(states)``.

    Returns:
        The artists modified this frame, as required by ``blit=True``.
    """
    x, _, theta, _ = states[i]  # always unpack full state for rendering

    # Cart is centred on x and straddles the track line (y = 0).
    cart.set_x(x - CART_W / 2)
    cart.set_y(-CART_H / 2)

    # theta = 0 points the pole straight up; positive theta tips it toward +x.
    px = x + POLE_LEN * math.sin(theta)
    py = POLE_LEN * math.cos(theta)
    pole_line.set_data([x, px], [0, py])
    pivot_dot.set_data([x], [0])
    tip_dot.set_data([px], [py])

    # Wrap theta into [-180, 180) degrees for display.
    theta_deg = math.degrees(((theta + math.pi) % (2 * math.pi)) - math.pi)
    step_text.set_text(f'Step: {i}')
    angle_text.set_text(f'θ = {theta_deg:.1f}°')

    # Idempotent: the trace depends only on i, never on previous calls.
    upto = i + 1
    reward_line.set_data(range(upto), rewards[:upto])
    return cart, pole_line, pivot_dot, tip_dot, step_text, angle_text, reward_line
# Finalise the layout before building the animation: with blit=True,
# FuncAnimation sets up blitting and draws an initial frame on construction,
# so it should see the final axes positions.
plt.tight_layout()
ani = FuncAnimation(fig, animate, frames=len(states), interval=30, blit=True)

# Prefer MP4 when an ffmpeg binary is available; otherwise fall back to a GIF
# written by Pillow. Checking availability up front, rather than catching every
# exception from the MP4 save, lets genuine rendering errors surface instead of
# being reported as a missing ffmpeg.
if FFMpegWriter.isAvailable():
    writer = FFMpegWriter(fps=30, bitrate=1800)
    ani.save(f"{OUTPUT_NAME}.mp4", writer=writer)
    print(f"Saved {OUTPUT_NAME}.mp4")
else:
    ani.save(f"{OUTPUT_NAME}.gif", writer='pillow', fps=30)
    print(f"Saved {OUTPUT_NAME}.gif (install ffmpeg for mp4)")
plt.show()