diff --git a/plantcv/learn/train_kmeans.py b/plantcv/learn/train_kmeans.py index e5ee33e16..b27ff4017 100644 --- a/plantcv/learn/train_kmeans.py +++ b/plantcv/learn/train_kmeans.py @@ -83,6 +83,9 @@ def patch_extract(img, patch_size=10, sigma=5, sampling=None, seed=1): img_blur = np.round(gaussian(img, sigma=sigma)*255).astype(np.uint16) elif len(img.shape) == 3 and img.shape[2] == 3: img_blur = np.round(gaussian(img, sigma=sigma, channel_axis=2)*255).astype(np.uint16) + elif len(img.shape) == 3 and img.shape[2] == 4: # rgb with alpha + img = img[:, :, :3] # removes alpha channel + img_blur = np.round(gaussian(img, sigma=sigma, channel_axis=2)*255).astype(np.uint16) # Extract patches patches = image.extract_patches_2d(img_blur, (patch_size, patch_size),