-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugment_data.py
More file actions
137 lines (110 loc) · 4.08 KB
/
augment_data.py
File metadata and controls
137 lines (110 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
EECS 445 - Introduction to Machine Learning
Winter 2024 - Project 2
Script to create an augmented dataset.
"""
import argparse
import csv
import glob
import os
import sys
import numpy as np
from scipy.ndimage import rotate
from imageio.v3 import imread, imwrite
import rng_control
def Rotate(deg=30):
    """Return function to rotate image by a random integer angle in [-deg, deg]."""
    def _rotate(img):
        """Rotate a random integer amount in the range [-deg, deg] (inclusive).

        Keep the dimensions the same and fill any missing pixels with black.

        Note: We have imported a helpful function from scipy.ndimage above.

        :img: H x W x C numpy array
        :returns: H x W x C numpy array
        """
        # The spec asks for a random *integer* angle; randint's upper bound
        # is exclusive, so add 1 to make [-deg, deg] inclusive.
        # (The previous np.random.uniform drew a float instead.)
        degree = np.random.randint(-deg, deg + 1)
        # reshape=False keeps the H x W dimensions; the default
        # mode='constant' with cval=0.0 fills exposed corners with black.
        return rotate(img, degree, reshape=False)
    return _rotate
def Grayscale():
    """Return function to grayscale image."""
    def _grayscale(img):
        """Return 3-channel grayscale of image.

        Compute grayscale values by taking average across the three channels.
        Round to the nearest integer.

        :img: H x W x C numpy array
        :returns: H x W x C numpy array
        """
        # astype(np.uint8) truncates toward zero, but the spec says to round
        # to the nearest integer — so round first, then cast.
        gray = np.round(np.mean(img, axis=2, keepdims=True)).astype(np.uint8)
        # Replicate the single gray channel back out to 3 channels.
        return np.repeat(gray, 3, axis=2)
    return _grayscale
def augment(filename, transforms, n=1, original=True):
    """Augment image at filename.

    :filename: name of image to be augmented
    :transforms: List of image transformations
    :n: number of augmented images to save
    :original: whether to include the original images in the augmented dataset or not
    :returns: a list of augmented images, where the first image is the original
    """
    # Report which file is being processed. The previous literal
    # "(unknown)" inside an f-string was a lost placeholder.
    print(f"Augmenting {filename}")
    img = imread(filename)
    # Seed the result list with the untouched original when requested.
    res = [img] if original else []
    for _ in range(n):
        # Apply the transform pipeline to a fresh copy of the original
        # each round so augmentations are independent.
        augmented = img
        for transform in transforms:
            augmented = transform(augmented)
        res.append(augmented)
    return res
def main(args):
    """Create augmented dataset.

    Reads `args.input` (a landmarks CSV), copies images from partitions
    outside `args.partitions` through unchanged, applies `augmentations`
    to rows in the selected partitions, and writes every output image to
    `<datadir>/augmented/` with a matching row in augmented_landmarks.csv.
    """
    augment_partitions = set(args.partitions)
    # TODO: change `augmentations` to specify which augmentations to apply
    augmentations = [Grayscale()]

    # Start from a clean output directory so stale files don't linger.
    os.makedirs(f"{args.datadir}/augmented/", exist_ok=True)
    for f in glob.glob(f"{args.datadir}/augmented/*"):
        print(f"Deleting {f}")
        os.remove(f)

    # Context managers close both CSV handles even on error (the originals
    # were opened inline and never closed); newline="" is the csv-module
    # convention to avoid blank lines on Windows.
    with open(args.input, "r", newline="") as infile, open(
        f"{args.datadir}/augmented_landmarks.csv", "w", newline=""
    ) as outfile:
        reader = csv.DictReader(infile, delimiter=",")
        writer = csv.DictWriter(
            outfile,
            fieldnames=["filename", "semantic_label", "partition", "numeric_label", "task"],
        )
        writer.writeheader()
        for row in reader:
            if row["partition"] not in augment_partitions:
                # Non-selected partitions: copy the image and row through as-is.
                imwrite(
                    f"{args.datadir}/augmented/{row['filename']}",
                    imread(f"{args.datadir}/images/{row['filename']}"),
                )
                writer.writerow(row)
                continue
            imgs = augment(
                f"{args.datadir}/images/{row['filename']}",
                augmentations,
                n=1,
                original=False,  # TODO: change to False to exclude original image.
            )
            for i, img in enumerate(imgs):
                # Strip the 4-char extension and tag with the augmentation index.
                fname = f"{row['filename'][:-4]}_aug_{i}.png"
                imwrite(f"{args.datadir}/augmented/{fname}", img)
                writer.writerow(
                    {
                        "filename": fname,
                        "semantic_label": row["semantic_label"],
                        "partition": row["partition"],
                        "numeric_label": row["numeric_label"],
                        "task": row["task"],
                    }
                )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Path to input CSV file")
    # nargs="?" makes the positional optional so the declared default can
    # actually apply; argparse silently ignores `default` on a required
    # positional. Existing two-argument invocations are unaffected.
    parser.add_argument("datadir", nargs="?", help="Data directory", default="./data/")
    parser.add_argument(
        "-p",
        "--partitions",
        nargs="+",
        help="Partitions (train|val|test|challenge|none)+ to apply augmentations to. Defaults to train",
        default=["train"],
    )
    main(parser.parse_args(sys.argv[1:]))