-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
63 lines (50 loc) · 1.98 KB
/
dataset.py
File metadata and controls
63 lines (50 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from pathlib import Path
import cv2
import csv
from tqdm import tqdm
class MirrorGazeDataset(torch.utils.data.Dataset):
    """Dataset of cropped frame images paired with normalized dot coordinates.

    For every participant id, each ``log_<rec>.csv`` found under
    ``<collected_root>/p<pid>/`` is matched with the crop directory
    ``<crops_root>/p<pid>_<rec>/``. Every ``<frame_id>.png`` in that directory
    whose frame id also appears in the log becomes one sample. Participants
    and recordings whose directories do not exist are skipped silently.

    Only image paths and targets are kept in memory; images are decoded on
    access in ``__getitem__``.
    """

    def __init__(
        self,
        pids,
        crops_root="../data/crops/",
        collected_root="../data/collected/",
        dot_scale=(806.0, 1194.0),
    ):
        """Scan the data directories and index all usable samples.

        Args:
            pids: Iterable of participant identifiers; each one is formatted
                as ``p{pid}`` to build directory names.
            crops_root: Directory containing the ``p<pid>_<rec>`` crop
                folders. Relative paths resolve against the current working
                directory.
            collected_root: Directory containing the ``p<pid>`` folders with
                the ``log_*.csv`` files.
            dot_scale: ``(x_divisor, y_divisor)`` applied to the ``dot_x`` and
                ``dot_y`` log columns. The defaults are the values that were
                previously hard-coded (presumably the width and height of the
                dot coordinate space - confirm against the recorder).

        Raises:
            ValueError: If either divisor in ``dot_scale`` is zero.
        """
        super().__init__()
        self.pids = pids

        x_div, y_div = (float(v) for v in dot_scale)
        if x_div == 0.0 or y_div == 0.0:
            # Fail before scanning instead of with a ZeroDivisionError on the
            # first matched frame.
            raise ValueError(f"dot_scale divisors must be non-zero, got {dot_scale!r}")

        path_to_crops = Path(crops_root)
        path_to_collected = Path(collected_root)

        img_paths = []
        gts = []
        for pid in tqdm(self.pids):
            pid_dir = path_to_collected / f"p{pid}"
            if not pid_dir.exists():
                continue

            # Find all recordings for this participant.
            for log_path in sorted(pid_dir.glob("log_*.csv")):
                # "log_<rec>.csv" -> "<rec>"
                rec_num = log_path.stem.replace("log_", "")
                crop_dir = path_to_crops / f"p{pid}_{rec_num}"
                if not crop_dir.exists():
                    continue

                # Load the log CSV to get ground truth, keyed by frame id.
                # newline="" is what the csv module expects for its input.
                log_rows = {}
                with open(log_path, "r", newline="") as f:
                    for row in csv.DictReader(f):
                        log_rows[int(row["frame_id"])] = (
                            float(row["dot_x"]),
                            float(row["dot_y"]),
                        )

                # Crops are named "<frame_id>.png"; keep only logged frames.
                for img_path in sorted(crop_dir.glob("*.png")):
                    entry = log_rows.get(int(img_path.stem))
                    if entry is None:
                        continue
                    dot_x, dot_y = entry
                    gts.append([dot_x / x_div, dot_y / y_div])
                    img_paths.append(str(img_path))

        print(f"Loaded dataset: {len(img_paths)} samples")
        self.img_paths = img_paths
        self.gts = gts

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the sample list.

        Returns:
            A ``(img, gt)`` tuple. ``img`` is a float tensor in channel-first
            layout with RGB channel order, divided by 255. ``gt`` is a float
            tensor holding the two scaled dot coordinates.

        Raises:
            RuntimeError: If the image file cannot be read or decoded.
        """
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than
            # raising; surface the offending path instead of letting
            # cvtColor fail with an opaque assertion error.
            raise RuntimeError(f"Failed to read image: {img_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC -> CHW, then scale by 1/255.
        img = torch.tensor(img).float().permute(2, 0, 1) / 255.0
        gt = torch.tensor(self.gts[idx]).float()
        return img, gt