forked from snehasinghania/STResNet
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
32 lines (28 loc) · 751 Bytes
/
Copy pathutils.py
File metadata and controls
32 lines (28 loc) · 751 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
'''
Author: Sneha Singhania
This file contains helper functions for running the main neural network.
'''
import numpy as np
def batch_generator(X, y, batch_size):
    """
    Yield shuffled (X, y) mini-batches indefinitely.

    The rows of ``X`` and ``y`` are shuffled together with one shared random
    permutation drawn from NumPy's global random state (``np.random``), then
    consecutive slices of ``batch_size`` rows are yielded. When fewer than
    ``batch_size`` rows remain in the current pass, those leftover rows are
    skipped for that pass, the data is reshuffled, and a new pass starts from
    the beginning. Every yielded batch therefore has exactly ``batch_size``
    rows.

    Parameters
    ----------
    X : array indexed along its first axis (presumably ``np.ndarray``)
        Inputs; ``X.shape[0]`` is the number of samples.
    y : array indexed along its first axis (presumably ``np.ndarray``)
        Targets, row-aligned with ``X``.
    batch_size : int
        Rows per batch; must satisfy ``1 <= batch_size <= X.shape[0]``.

    Yields
    ------
    tuple
        ``(X_batch, y_batch)`` slices of shuffled copies. The caller's ``X``
        and ``y`` are never modified.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` have different numbers of rows, or if
        ``batch_size`` is outside ``[1, X.shape[0]]``. Because this is a
        generator, the error surfaces on the first ``next()`` call.
    """
    size = X.shape[0]
    if len(y) != size:
        raise ValueError(
            "X and y must have the same number of rows, got %d and %d"
            % (size, len(y))
        )
    if not 1 <= batch_size <= size:
        # Without this guard, batch_size > size would reshuffle forever
        # without ever yielding, and batch_size <= 0 would never advance.
        raise ValueError(
            "batch_size must be between 1 and %d, got %r" % (size, batch_size)
        )

    X_copy, y_copy = X, y
    while True:
        # Start a new pass: permute X and y with the same index array so the
        # (X, y) pairs stay aligned. Fancy indexing returns new arrays, so
        # the caller's data is never touched.
        perm = np.random.permutation(size)
        X_copy = X_copy[perm]
        y_copy = y_copy[perm]
        # Only full batches are yielded; a trailing remainder of
        # size % batch_size rows is dropped for this pass.
        for start in range(0, size - batch_size + 1, batch_size):
            stop = start + batch_size
            yield X_copy[start:stop], y_copy[start:stop]