Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions utils/process_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, image_glob):
self.create_image_thumbs()
self.create_image_vectors()
self.load_image_vectors()
self.build_clustering()
self.write_json()
self.create_atlas_files()
print('Processed output for ' + \
Expand Down Expand Up @@ -151,7 +152,7 @@ def create_image_vectors(self):
self.create_tf_graph()

print(' * creating image vectors')
with tf.Session() as sess:
with tf.compat.v1.Session() as sess:
for image_index, image in enumerate(self.image_files):
try:
print(' * processing image', image_index+1, 'of', len(self.image_files))
Expand All @@ -160,7 +161,7 @@ def create_image_vectors(self):
if os.path.exists(out_path) and not self.rewrite_image_vectors:
continue
# save the penultimate inception tensor/layer of the current image
with tf.gfile.FastGFile(image, 'rb') as f:
with tf.io.gfile.GFile(image, 'rb') as f:
data = {'DecodeJpeg/contents:0': f.read()}
feature_tensor = sess.graph.get_tensor_by_name('pool_3:0')
feature_vector = np.squeeze( sess.run(feature_tensor, data) )
Expand Down Expand Up @@ -199,8 +200,8 @@ def create_tf_graph(self):
'''
print(' * creating tf graph')
graph_path = join(FLAGS.model_dir, 'classify_image_graph_def.pb')
with tf.gfile.FastGFile(graph_path, 'rb') as f:
graph_def = tf.GraphDef()
with tf.io.gfile.GFile(graph_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')

Expand Down Expand Up @@ -253,35 +254,45 @@ def get_image_positions(self, fit_model):
thumb_path = join(self.output_dir, 'thumbs', '32px', img)
with Image.open(thumb_path) as image:
width, height = image.size
# Add the image name, x offset, y offset
cluster = int(self.clustering.labels_[c]) + 1 # Because PixPlot.get_centroids() names them from 1 onwards
# Add the image name, x offset, y offset, cluster
image_positions.append([
os.path.splitext(os.path.basename(img))[0],
int(i[0] * 100),
int(i[1] * 100),
width,
height
height,
cluster
])
return image_positions


def build_clustering(self):
    '''
    Fit a KMeans model with self.n_clusters clusters to the image vectors
    and store the fitted model on self.clustering.

    The fitted model is kept on the instance so other methods can read its
    cluster_centers_ and labels_ without fitting again.
    '''
    print(' * calculating ' + str(self.n_clusters) + ' clusters')
    model = KMeans(n_clusters=self.n_clusters)
    X = np.array(self.image_vectors)
    self.clustering = model.fit(X)


def get_centroids(self):
    '''
    Find n centroid images that represent the center of an image cluster
    and list the members of each cluster.

    Reads the fitted model stored on self.clustering.

    Returns a list with one dict per cluster:
      img: get_filename() of the vector file nearest the cluster centroid
      label: 'Cluster N', numbered from 1
      members: positions in the clustering labels assigned to that cluster
    '''
    centroids = self.clustering.cluster_centers_
    labels = self.clustering.labels_
    X = np.array(self.image_vectors)
    # find the points closest to the cluster centroids
    closest, _ = pairwise_distances_argmin_min(centroids, X)
    # bucket item indices by cluster in a single pass over the labels
    # rather than rescanning every label once per cluster
    members = [[] for _ in closest]
    for item, label in enumerate(labels):
        # ignore labels that do not correspond to a centroid
        if 0 <= label < len(members):
            members[int(label)].append(item)
    return [
        {
            'img': get_filename(self.vector_files[i]),
            'label': 'Cluster ' + str(c + 1),
            'members': members[c],
        }
        for c, i in enumerate(closest)
    ]

Expand Down Expand Up @@ -480,4 +491,4 @@ def main(*args, **kwargs):
PixPlot(image_glob)

if __name__ == '__main__':
    # tf.app is only available under the compat.v1 namespace in TensorFlow 2.x
    tf.compat.v1.app.run()