-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFrameDataset.py
More file actions
56 lines (45 loc) · 2.19 KB
/
FrameDataset.py
File metadata and controls
56 lines (45 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from einops import repeat
class FrameDataset(Dataset):
    """Fixed-length clips of consecutive frames drawn from selected splits.

    ``frames`` is divided into equal-size contiguous splits, and only the
    splits listed in ``splitIdxs`` are exposed. Each item is a window of
    ``seqLen`` frames. The frames are scaled by 1/255, normalized with
    ``stats``, and returned with dims permuted ``(1, 0, 2, 3)``, together
    with a label taken from inside the window.

    Args:
        frames: Indexable along dim 0. A single ``frames[i]`` must be a
            ``(c, h, w)`` tensor, as required by the einops pattern used for
            "stop" samples. Presumably uint8 pixel values in [0, 255]
            (TODO confirm).
        stats: ``(mean, std)``, passed to ``transforms.Normalize``.
        labels: Per-frame labels; must have the same length as ``frames``.
        train: Enables the random stop/double variants and the random
            horizontal flip.
        seqLen: Number of frames per returned clip.
        numSplits: Requested number of splits. It is incremented by one when
            ``len(frames)`` is not divisible by it.
        splitIdxs: Indices of the splits to expose, in any order (they are
            sorted here). Accepts a tensor or any sequence of ints. ``None``
            (the default) exposes every split.
        doublePct: In training, the probability mass for a 2x-stride clip
            whose label is doubled. Falls back to a normal clip when the
            stretched window would run past the end of ``frames``.
        stopPct: In training, the probability of returning one frame
            repeated ``seqLen`` times with label 0.

    Raises:
        ValueError: If ``frames`` and ``labels`` differ in length.
    """

    def __init__(self, frames, stats, labels, train=True, seqLen=20, numSplits=10,
                 splitIdxs=None, doublePct=0.0, stopPct=0.0):
        if len(frames) != len(labels):
            raise ValueError(
                f"frames and labels differ in length ({len(frames)} vs {len(labels)})"
            )
        self.normalize = transforms.Normalize(stats[0], stats[1], inplace=True)
        self.augment = transforms.RandomHorizontalFlip()
        self.doublePct = doublePct
        self.stopPct = stopPct
        self.train = train
        self.seqLen = seqLen
        self.labels = labels
        self.frames = frames
        if len(frames) % numSplits != 0:
            numSplits += 1
        self.splitSize = len(frames) // numSplits
        if splitIdxs is None:
            splitIdxs = torch.arange(numSplits)
        # as_tensor accepts plain lists as well as tensors; keep it on the CPU
        # so it can be combined with the CPU arange below.
        splitIdxs = torch.sort(torch.as_tensor(splitIdxs, device="cpu")).values
        # Dataset indices are packed contiguously over the selected splits.
        # The i-th selected split occupies dataset range
        # [i, i + 1) * splitSize but physical range
        # [splitIdxs[i], splitIdxs[i] + 1) * splitSize, so its offset is
        # (splitIdxs[i] - i) * splitSize.
        positions = torch.arange(len(splitIdxs))
        self.splits = ((splitIdxs - positions) * self.splitSize).to(torch.int32)

    def __len__(self):
        """Return the number of valid start indices (never negative).

        NOTE(review): a start in the last ``seqLen - 1`` positions of a
        non-final selected split yields a window that extends into the next
        physical split, which may not be selected. Confirm this overlap
        between selections is intended.
        """
        return max(0, len(self.splits) * self.splitSize - self.seqLen + 1)

    def __getitem__(self, idx):
        """Return ``(clip, label)`` for dataset index ``idx``.

        ``clip`` is the selected frames permuted with dims ``(1, 0, 2, 3)``,
        i.e. ``(c, seqLen, h, w)`` when a single frame is ``(c, h, w)``.
        Negative indices count from the end, as for Python sequences.

        Raises:
            TypeError: If ``idx`` is a tensor.
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if torch.is_tensor(idx):
            raise TypeError("FrameDataset indices must be Python ints, not tensors")
        length = len(self)
        pos = idx + length if idx < 0 else idx
        if not 0 <= pos < length:
            raise IndexError(f"index {idx} out of range for FrameDataset of length {length}")
        # int() keeps the physical start index a plain Python int
        # rather than a 0-d tensor.
        start = pos + int(self.splits[pos // self.splitSize])
        if self.train:
            rand = torch.rand(1).item()
            if rand < self.stopPct:
                # "Stop" sample: one frame repeated seqLen times, label 0.
                frames = self.normalize(self.frames[start] / 255)
                frames = repeat(frames, 'c h w -> k c h w', k=self.seqLen)
                label = torch.tensor(0.0)
            elif (rand < self.stopPct + self.doublePct
                  and start + self.seqLen * 2 - 1 <= len(self.frames)):
                # 2x-stride sample: every second frame over 2*seqLen - 1
                # physical frames, labelled at the window centre and doubled.
                # NOTE(review): the bound is len(self.frames), not the end of
                # the current split, so this can read unselected splits.
                frames = self._clip(start, step=2)
                label = self.labels[start + self.seqLen - 1] * 2
            else:
                frames = self._clip(start)
                label = self.labels[start + self.seqLen // 2]
            frames = self.augment(frames)
        else:
            frames = self._clip(start)
            label = self.labels[start + self.seqLen // 2]
        return frames.permute(1, 0, 2, 3), label

    def _clip(self, start, step=1):
        """Return ``seqLen`` frames from ``start`` at ``step`` stride, scaled and normalized."""
        return self.normalize(self.frames[start : start + self.seqLen * step : step] / 255)