-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
72 lines (52 loc) · 2 KB
/
Copy pathutils.py
File metadata and controls
72 lines (52 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import pandas as pd
from PIL import Image
import torch
# MTCNN
from torch.utils.data import (
Dataset,
DataLoader,
) # Gives easier dataset managment and creates mini batches
class PeopleDataset(Dataset):
    """Map-style dataset of face crops listed in an annotations CSV.

    Rows of the CSV are read positionally: column 0 is an image path relative
    to ``root_dir`` and column 1 is that image's label. Each image is run
    through the module-level ``mtcnn`` detector. When the detector returns
    ``None`` (facenet_pytorch's MTCNN does this when no face is found), the
    pre-computed ``(default_img, default_label)`` pair is returned instead.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        root_dir: Directory that the relative image paths are joined onto.
        default_label: Label returned when the detector yields ``None``.
        default_img: Sample returned when the detector yields ``None``.
        transform: Stored on the instance; NOTE it is currently not applied
            by ``__getitem__``.
    """

    def __init__(self, csv_file, root_dir, default_label, default_img, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.default_label = default_label
        self.default_img = default_img

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(face_crop, label)`` for annotation row ``index``.

        Falls back to ``(self.default_img, self.default_label)`` when the
        detector returns ``None``.
        """
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        # The context manager releases the file handle even if detection raises.
        with Image.open(img_path) as image:
            # ``mtcnn`` is a module-level global defined later in this file;
            # it is looked up when this method runs, not at class definition.
            img_cropped = mtcnn(image)
        y_label = self.annotations.iloc[index, 1]
        # Identity check: ``==`` against a tensor is not a reliable None test.
        if img_cropped is None:
            return self.default_img, self.default_label
        return img_cropped, y_label
from facenet_pytorch import MTCNN

# Face detector shared by PeopleDataset.__getitem__ (resolved as a module
# global at call time) and by the fallback-sample computation below.
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device="cpu",
)

root_dir = "CASIA_dataset/Images"
# Fallback sample handed to PeopleDataset for images where detection yields None.
label_default = 4250
# Get cropped and prewhitened image tensor; the context manager closes the
# underlying file once detection has consumed the pixels.
with Image.open(os.path.join(root_dir, '0.png')) as img:
    img_cropped_default = mtcnn(img)
# NOTE(review): if no face is detected in 0.png, img_cropped_default is None
# and the fallback itself yields None, which DataLoader's default collation
# cannot batch -- confirm 0.png always contains a detectable face.

# Load Data
dataset = PeopleDataset(
    csv_file="CASIA_dataset/annotations.csv",
    root_dir=root_dir,
    default_label=label_default,
    default_img=img_cropped_default,
)
# DATA LOADER
def create_loader(batch_size, num_workers, fraction=1, data=None, generator=None):
    """Build a shuffling DataLoader over a random subset of a dataset.

    Args:
        batch_size: Samples per mini-batch.
        num_workers: Worker processes passed through to ``DataLoader``.
        fraction: Share of the dataset to keep, in ``(0, 1]``. The default 1
            keeps every sample.
        data: Dataset to draw from; defaults to the module-level ``dataset``.
        generator: Optional ``torch.Generator`` making the subset selection
            reproducible; when omitted, torch's default generator is used.

    Returns:
        A ``DataLoader`` (``shuffle=True``) over a random ``Subset`` holding
        ``int(len(data) * fraction)`` samples.

    Raises:
        ValueError: If ``fraction`` is not in ``(0, 1]``.
    """
    # Reject values that would produce a negative or empty split length.
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
    source = dataset if data is None else data
    total = len(source)
    keep = int(total * fraction)
    # Only forward ``generator`` when given, so the default path matches
    # random_split's own default behaviour.
    split_kwargs = {} if generator is None else {"generator": generator}
    train_set, _ = torch.utils.data.random_split(source, [keep, total - keep], **split_kwargs)
    return DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)