diff --git a/utils/process_images.py b/utils/process_images.py index 626d6d45..37fa8e2b 100755 --- a/utils/process_images.py +++ b/utils/process_images.py @@ -68,6 +68,7 @@ def __init__(self, image_glob): self.create_image_thumbs() self.create_image_vectors() self.load_image_vectors() + self.build_clustering() self.write_json() self.create_atlas_files() print('Processed output for ' + \ @@ -151,7 +152,7 @@ def create_image_vectors(self): self.create_tf_graph() print(' * creating image vectors') - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: for image_index, image in enumerate(self.image_files): try: print(' * processing image', image_index+1, 'of', len(self.image_files)) @@ -160,7 +161,7 @@ def create_image_vectors(self): if os.path.exists(out_path) and not self.rewrite_image_vectors: continue # save the penultimate inception tensor/layer of the current image - with tf.gfile.FastGFile(image, 'rb') as f: + with tf.io.gfile.GFile(image, 'rb') as f: data = {'DecodeJpeg/contents:0': f.read()} feature_tensor = sess.graph.get_tensor_by_name('pool_3:0') feature_vector = np.squeeze( sess.run(feature_tensor, data) ) @@ -199,8 +200,8 @@ def create_tf_graph(self): ''' print(' * creating tf graph') graph_path = join(FLAGS.model_dir, 'classify_image_graph_def.pb') - with tf.gfile.FastGFile(graph_path, 'rb') as f: - graph_def = tf.GraphDef() + with tf.io.gfile.GFile(graph_path, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name='') @@ -253,27 +254,36 @@ def get_image_positions(self, fit_model): thumb_path = join(self.output_dir, 'thumbs', '32px', img) with Image.open(thumb_path) as image: width, height = image.size - # Add the image name, x offset, y offset + cluster = int(self.clustering.labels_[c]) + 1 # Because PixPlot.get_centroids() names them from 1 onwards + # Add the image name, x offset, y offset, cluster image_positions.append([ os.path.splitext(os.path.basename(img))[0], int(i[0] * 100), int(i[1] * 100), width, - height + height, + cluster ]) return image_positions - - def get_centroids(self): + def build_clustering(self): ''' - Use KMeans clustering to find n centroid images - that represent the center of an image cluster + Use KMeans clustering to find n centroids. ''' print(' * calculating ' + str(self.n_clusters) + ' clusters') model = KMeans(n_clusters=self.n_clusters) X = np.array(self.image_vectors) fit_model = model.fit(X) - centroids = fit_model.cluster_centers_ + self.clustering = fit_model + + + def get_centroids(self): + ''' + Find n centroid images that represent the center of an image cluster + ''' + centroids = self.clustering.cluster_centers_ + labels = self.clustering.labels_ + X = np.array(self.image_vectors) # find the points closest to the cluster centroids closest, _ = pairwise_distances_argmin_min(centroids, X) centroid_paths = [self.vector_files[i] for i in closest] @@ -281,7 +291,8 @@ def get_centroids(self): for c, i in enumerate(centroid_paths): centroid_json.append({ 'img': get_filename(i), - 'label': 'Cluster ' + str(c+1) + 'label': 'Cluster ' + str(c+1), + 'members': [item for item, label in enumerate(labels) if label == c] }) return centroid_json @@ -480,4 +491,4 @@ def main(*args, **kwargs): PixPlot(image_glob) if __name__ == '__main__': - tf.app.run() + tf.compat.v1.app.run()