-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreplotting
More file actions
162 lines (128 loc) · 6.09 KB
/
replotting
File metadata and controls
162 lines (128 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# replotting
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
import numpy as np
import os
# Import the necessary components from your existing simulation code
from config import CONFIG
from environment import plot_environment
# ==============================================================================
# --- USER CONFIGURATION ---
# ==============================================================================

# Agent IDs to leave out of every plot. Edit freely: e.g. [1, 5, 8] drops
# exactly those three agents. Values are matched against the CSV's ``id``
# column, and the order here is the order echoed back when the script runs.
IDS_TO_EXCLUDE = [
    2, 3, 23, 36, 39, 59, 65, 91, 72, 12, 88,
]

# Input/output locations. The directory is taken from the
# ``animation_output_path`` entry of config.py's CONFIG (falling back to the
# current directory); the CSV file name itself is fixed below.
OUTPUT_DIRECTORY = CONFIG.get("animation_output_path", ".")
_CSV_FILE_NAME = "full_simulation_trajectories.csv"
CSV_FILE_PATH = os.path.join(OUTPUT_DIRECTORY, _CSV_FILE_NAME)

# ==============================================================================
# --- SCRIPT LOGIC (No need to edit below this line) ---
# ==============================================================================
def plot_filtered_ground_trajectories(df, config, output_dir):
    """
    Plot the trajectories of ground agents (vehicles, pedestrians, bikes) from a DataFrame.

    One polyline is drawn per unique ``id`` through that agent's ``raw_x`` /
    ``raw_y`` samples, coloured by its ``agent_type``, on top of the map drawn
    by ``plot_environment``. The figure is saved as
    ``ground_trajectories_filtered.png`` inside ``output_dir``.

    :param df: DataFrame containing the trajectory data for ground agents; it
        must provide the ``id``, ``agent_type``, ``raw_x`` and ``raw_y``
        columns. Points are joined in the row order of each agent's group
        (presumably chronological -- confirm against the CSV writer).
    :param config: The simulation configuration dictionary, passed through to
        ``plot_environment``.
    :param output_dir: The directory to save the plot; created if missing.
    :return: None. Nothing is drawn or saved when ``df`` is empty.
    """
    print("\n--- Plotting filtered ground agent trajectories ---")
    if df.empty:
        # Mirror the drone plotter: don't write a map with no trajectories on it.
        print("No ground trajectories to plot.")
        return

    from matplotlib.lines import Line2D

    # Colour and legend label per known ground type (dict order = legend order).
    # Types not listed here are drawn in gray and get no legend entry.
    styles = {
        'vehicle': ('deepskyblue', 'Vehicle Trajectory'),
        'pedestrian': ('fuchsia', 'Pedestrian Trajectory'),
        'bike': ('lime', 'Bike Trajectory'),
    }

    fig, ax = plt.subplots(figsize=(14, 14))
    try:
        plot_environment(config, ax=ax, show_street_names=False)

        # Group by agent ID and plot each trajectory.
        for _, group in df.groupby('id'):
            agent_type = group['agent_type'].iloc[0]
            color = styles.get(agent_type, ('gray', None))[0]
            ax.plot(group['raw_x'], group['raw_y'], color=color,
                    linewidth=1.5, alpha=0.8, zorder=8)

        ax.set_title("Filtered Ground Agent Trajectories (Vehicles, Pedestrians, and Bikes)")

        # Only list the agent types that actually appear in the data.
        present_types = set(df['agent_type'].unique())
        legend_elements = [
            Line2D([0], [0], color=color, lw=2, label=label)
            for agent_type, (color, label) in styles.items()
            if agent_type in present_types
        ]
        if legend_elements:
            ax.legend(handles=legend_elements, loc='upper right')

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        filepath = os.path.join(output_dir, "ground_trajectories_filtered.png")
        # Save this figure explicitly rather than whatever pyplot deems "current".
        fig.savefig(filepath)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    print(f"Saved filtered ground trajectory plot to {filepath}")
def plot_filtered_drone_trajectories(df, config, output_dir):
    """
    Plot the trajectories of drone agents with altitude-based coloring from a DataFrame.

    Each agent (one group per unique ``id``) with at least two samples is drawn
    as a ``LineCollection`` of consecutive ``raw_x`` / ``raw_y`` segments. Each
    segment is coloured by the mean ``raw_z`` of its two endpoints, on a
    ``viridis`` scale shared by all drones. A colorbar labelled
    'Altitude (m)' is added, and the figure is saved as
    ``drone_trajectories_filtered.png`` inside ``output_dir``.

    :param df: DataFrame containing the trajectory data for drone agents; it
        must provide the ``id``, ``raw_x``, ``raw_y`` and ``raw_z`` columns.
        Points are joined in the row order of each agent's group
        (presumably chronological -- confirm against the CSV writer).
    :param config: The simulation configuration dictionary, passed through to
        ``plot_environment``.
    :param output_dir: The directory to save the plot; created if missing.
    :return: None. Nothing is drawn or saved when ``df`` is empty.
    """
    print("\n--- Plotting filtered drone trajectories with altitude coloring ---")
    if df.empty:
        # Check before creating a figure so no environment is drawn for nothing.
        print("No drone trajectories to plot.")
        return

    # One colour scale shared by every drone so colours are comparable across agents.
    min_alt, max_alt = df['raw_z'].min(), df['raw_z'].max()
    if min_alt == max_alt:
        # With vmin == vmax, Normalize maps every value to 0 and the colorbar
        # spans a zero-width range; widen it so a constant altitude renders sensibly.
        min_alt, max_alt = min_alt - 0.5, max_alt + 0.5
    cmap = plt.get_cmap('viridis')
    norm = Normalize(vmin=min_alt, vmax=max_alt)

    fig, ax = plt.subplots(figsize=(14, 14))
    try:
        plot_environment(config, ax=ax, show_street_names=False)

        for _, group in df.groupby('id'):
            if len(group) < 2:
                # A single sample has no segment to draw.
                continue
            # (N, 1, 2) points -> (N-1, 2, 2) segments joining consecutive samples.
            points = group[['raw_x', 'raw_y']].to_numpy().reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            # Colour each segment by the mean altitude of its two endpoints.
            altitudes = group['raw_z'].to_numpy()
            lc = LineCollection(segments, cmap=cmap, norm=norm)
            lc.set_array((altitudes[:-1] + altitudes[1:]) / 2)
            lc.set_linewidth(2)
            lc.set_alpha(0.8)
            lc.set_zorder(10)
            ax.add_collection(lc)

        # add_collection updates the data limits but does not rescale the view.
        # This is a no-op on any axis whose limits were fixed explicitly
        # (autoscaling off), so an environment-defined extent is preserved.
        ax.autoscale_view()

        ax.set_title("Filtered Drone Trajectories (Color Mapped to Altitude)")
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_label('Altitude (m)')

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        filepath = os.path.join(output_dir, "drone_trajectories_filtered.png")
        # Save this figure explicitly rather than whatever pyplot deems "current".
        fig.savefig(filepath)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    print(f"Saved filtered drone trajectory plot to {filepath}")
def main():
    """
    Load the trajectory CSV, drop excluded agents, and generate the plots.

    Reads ``CSV_FILE_PATH`` and removes every row whose ``id`` is in
    ``IDS_TO_EXCLUDE``. The remaining rows are split into ground agents
    (``vehicle``/``pedestrian``/``bike``) and drones (``drone``), and one PNG
    per group is saved into ``OUTPUT_DIRECTORY``. Problems (missing file,
    missing columns, nothing left to plot) are reported with ``print`` and
    end the run early instead of raising.
    """
    # --- 1. Load Data ---
    if not os.path.exists(CSV_FILE_PATH):
        print(f"Error: CSV file not found at '{CSV_FILE_PATH}'")
        return
    print(f"Loading data from: {CSV_FILE_PATH}")
    all_trajectories_df = pd.read_csv(CSV_FILE_PATH)

    # Report a schema problem clearly instead of a KeyError deep inside plotting.
    required_columns = {'id', 'agent_type', 'raw_x', 'raw_y'}
    missing_columns = required_columns - set(all_trajectories_df.columns)
    if missing_columns:
        print(f"Error: CSV is missing required columns: {sorted(missing_columns)}")
        return

    # --- 2. Filter Data ---
    print(f"Excluding agent IDs: {IDS_TO_EXCLUDE}")
    present_ids = set(all_trajectories_df['id'].unique())
    # An exclusion that matches nothing is usually a typo, or a type mismatch
    # such as string IDs in the CSV versus integers in IDS_TO_EXCLUDE.
    unmatched_ids = [agent_id for agent_id in IDS_TO_EXCLUDE if agent_id not in present_ids]
    if unmatched_ids:
        print(f"Warning: these excluded IDs were not found in the data: {unmatched_ids}")
    filtered_df = all_trajectories_df[~all_trajectories_df['id'].isin(IDS_TO_EXCLUDE)].copy()
    if filtered_df.empty:
        print("No data remains after filtering. No plots will be generated.")
        return

    # --- 3. Separate Agent Types ---
    ground_agent_types = ['vehicle', 'pedestrian', 'bike']
    ground_df = filtered_df[filtered_df['agent_type'].isin(ground_agent_types)]
    drone_df = filtered_df[filtered_df['agent_type'] == 'drone']
    # Rows of any other type fall through both masks above; say so rather than drop them silently.
    other_types = set(filtered_df['agent_type'].unique()) - set(ground_agent_types) - {'drone'}
    if other_types:
        print(f"Note: ignoring rows with unrecognised agent types: {sorted(map(str, other_types))}")

    # --- 4. Generate Plots ---
    # exist_ok avoids the check-then-create race; an empty path means "current
    # directory", which needs no creation (os.makedirs('') would raise).
    if OUTPUT_DIRECTORY:
        os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)

    if not ground_df.empty:
        plot_filtered_ground_trajectories(ground_df, CONFIG, OUTPUT_DIRECTORY)
    else:
        print("\nNo ground agent data to plot.")

    if drone_df.empty:
        print("\nNo drone agent data to plot.")
    elif 'raw_z' not in drone_df.columns:
        # Altitude colouring is impossible without the z coordinate.
        print("Error: CSV has drone rows but no 'raw_z' column; skipping drone plot.")
    else:
        plot_filtered_drone_trajectories(drone_df, CONFIG, OUTPUT_DIRECTORY)


if __name__ == "__main__":
    main()