Image Data Augmentation using Keras
Why augmenting images helps a CNN reach better accuracy, and how to do it with the augmentation techniques available in Keras.
In this blog we will discuss the following:
- The need for data augmentation
- What is data augmentation
- A detailed explanation of data augmentation techniques with Keras
- A GitHub link to the complete code at the end
You might be wondering why I start with the need for data augmentation rather than its definition, but that’s the best way to learn anything quickly. So, let's dive into data augmentation.
The Need for Data Augmentation
Training data is the backbone of any deep learning project: the more data we have, the more features the model can learn, and the better its accuracy tends to be. Deep learning models depend heavily on the amount of data, but we don’t always have enough images to train them. Data augmentation is a practical way to address this problem.
Data augmentation is an integral part of deep learning. Models need large amounts of data, and in some cases it is not feasible to collect thousands or millions of images, so data augmentation comes to the rescue.
What is Data Augmentation?
Data augmentation is the process of increasing the amount and diversity of data. Instead of collecting new data, we modify existing images so that each modification produces an almost new image. This helps in three ways:
- It increases the amount of training data.
- It exposes the classifier to a wider variety of lighting, colour, and orientation conditions, making it more robust.
- It reduces overfitting. Because the model sees many slightly different versions of each image, it cannot simply memorize individual samples and is pushed to learn features that generalize.
Steps to Perform Data Augmentation
ImageDataGenerator
Keras provides the ImageDataGenerator class, which performs data augmentation automatically. It generates batches of image data with real-time augmentation: the random transformations are applied on the fly as each batch is produced, instead of storing augmented copies on disk. The class takes various arguments that control how the data is modified.
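To make "real-time augmentation" concrete, here is a minimal sketch in plain NumPy (not Keras's actual implementation) of a generator that applies a random transform to each batch at the moment it is requested. The names `augment_batches` and `random_hflip` are hypothetical, and only a random horizontal flip is used to keep it short.

```python
import numpy as np

def random_hflip(image, rng):
    """Flip an (H, W, C) image left-right with probability 0.5."""
    return image[:, ::-1, :] if rng.random() < 0.5 else image

def augment_batches(images, batch_size, seed=0):
    """Endlessly yield augmented batches from an (N, H, W, C) array.

    Each batch is transformed on the fly, so no augmented copies are stored;
    every pass over the data sees fresh random variants of the same images.
    """
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        idx = rng.permutation(n)  # reshuffle on every epoch
        for start in range(0, n, batch_size):
            batch = images[idx[start:start + batch_size]]
            yield np.stack([random_hflip(img, rng) for img in batch])

# Two tiny 2x3 single-channel "images"
images = np.arange(2 * 2 * 3 * 1, dtype="float32").reshape(2, 2, 3, 1)
gen = augment_batches(images, batch_size=2)
batch = next(gen)
print(batch.shape)  # (2, 2, 3, 1)
```

The same idea underlies `datagen.flow(...)`: the generator holds only the original images and a description of the transforms.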
Initial steps to perform Data Augmentation
Import all necessary libraries
from numpy import expand_dims
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
Loading the image with Python
# load the image
img = load_img('sample_image.jpg')
# view the image
img.show()
# convert to numpy array
data = img_to_array(img)
# expand dimension to one sample
samples = expand_dims(data, 0)
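The `expand_dims` step matters because `ImageDataGenerator.flow` expects a rank-4 array of shape (batch, height, width, channels). A quick shape check, using a zero-filled 4x6 RGB array as an assumed stand-in for the output of `img_to_array(img)`:

```python
import numpy as np
from numpy import expand_dims

# Stand-in for img_to_array(img): a 4x6 RGB image (size chosen arbitrarily)
data = np.zeros((4, 6, 3), dtype="float32")
samples = expand_dims(data, 0)  # add a leading "batch" axis of length 1

print(data.shape)     # (4, 6, 3)
print(samples.shape)  # (1, 4, 6, 3)
```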
Function to show an image after augmentation
# Common function to view augmented images
def showimage(datagen):
    # prepare iterator over the single sample
    it = datagen.flow(samples, batch_size=1)
    # generate samples and plot
    for i in range(4):
        # define subplot
        plt.subplot(2, 2, i+1)
        # generate batch of images
        batch = next(it)
        # convert to unsigned integers for viewing
        image = batch[0].astype('uint8')
        # plot raw pixel data
        plt.imshow(image)
    # show the figure
    plt.show()
1. Horizontal Shift
width_shift_range = [-225, 225]
A shift to an image means moving all pixels of the image in one direction, such as horizontally or vertically while keeping the image dimensions the same. This means that some of the pixels will be clipped off the image and there will be a region of the image where new pixel values will have to be specified.
We can specify the width shift either as a fraction of the image width (a float below 1) or as a number of pixels. Here, I passed the list [-225, 225]. Note that Keras treats a list as a set of allowed values rather than a [min, max] range, so each image is shifted by exactly 225 pixels to the left or right. To sample a random whole-pixel shift anywhere in the open interval (-225, 225), pass the integer 225 instead.
By default, new pixels take the value of the nearest edge pixel (fill_mode='nearest'). This can be changed:
- fill_mode='reflect': the image is mirrored across its edge to fill the new pixels.
- fill_mode='wrap': the image is tiled, so pixels from the opposite edge wrap around.
- fill_mode='constant': new pixels are filled with a single constant value, set via the cval argument (default 0).
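The four fill modes are easiest to compare on a single row of pixels. The sketch below is not how Keras implements shifts (it uses an affine transform internally); it is a pad-and-crop illustration that assumes each Keras mode corresponds to an `np.pad` mode producing the same border pattern ('nearest' to 'edge', 'reflect' to 'symmetric'). `shift_horizontal` is a hypothetical helper.

```python
import numpy as np

# Keras fill_mode name -> np.pad mode that produces the same border pattern
PAD_MODES = {"nearest": "edge", "reflect": "symmetric", "wrap": "wrap", "constant": "constant"}

def shift_horizontal(image, shift, fill_mode="nearest", cval=0.0):
    """Shift an (H, W) or (H, W, C) image right by `shift` pixels (left if negative)."""
    s = abs(shift)
    if s == 0:
        return image.copy()
    width = image.shape[1]
    pad = [(0, 0)] * image.ndim
    pad[1] = (s, s)  # pad only the width axis, on both sides
    kwargs = {"constant_values": cval} if fill_mode == "constant" else {}
    padded = np.pad(image, pad, mode=PAD_MODES[fill_mode], **kwargs)
    # A right shift keeps the left end of the padded row; a left shift keeps the right end
    start = s - shift
    return padded[:, start:start + width]

row = np.array([[1, 2, 3, 4]])
for mode in PAD_MODES:
    print(mode, shift_horizontal(row, 2, fill_mode=mode))
# nearest [[1 1 1 2]]
# reflect [[2 1 1 2]]
# wrap [[3 4 1 2]]
# constant [[0 0 1 2]]
```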
2. Vertical Shift
height_shift_range = 0.3
This is similar to the horizontal shift, except that the image moves up or down. As before, you can pass either a number of pixels or a fraction of the image height. Here, 0.3 shifts the image by a random amount of up to 30% of its height in either direction.
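How a shift-range argument becomes an actual offset is easy to misread, so here is a sketch modeled on the behaviour described in the Keras ImageDataGenerator documentation: a float is sampled uniformly, an int or list is sampled from a discrete set with a random sign, and values below 1 are treated as fractions of the image size. `sample_shift_pixels` is a hypothetical helper, not a Keras function.

```python
import numpy as np

def sample_shift_pixels(shift_range, size, rng):
    """Draw one shift (in pixels) the way *_shift_range is documented to behave.

    - float: uniform in [-shift_range, shift_range)
    - int n: an integer in (-n, n)
    - list:  one of the listed values, with a random sign
    Values below 1 are treated as fractions of `size`.
    """
    if isinstance(shift_range, float):
        offset = rng.uniform(-shift_range, shift_range)
    else:  # int n -> pick from 0..n-1; list -> pick an entry; then a random sign
        offset = rng.choice(shift_range) * rng.choice([-1, 1])
    if np.max(shift_range) < 1:
        offset *= size
    return float(offset)

rng = np.random.default_rng(42)
height = 200  # assumed image height
print([round(sample_shift_pixels(0.3, height, rng)) for _ in range(5)])  # anywhere in about [-60, 60]
print([sample_shift_pixels([-225, 225], 400, rng) for _ in range(5)])    # only -225.0 or 225.0
```

The second line shows why [-225, 225] does not act as a range: no value in between is ever produced.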
3. Image Flipping
horizontal_flip = True, vertical_flip = True
An image flip means reversing the rows or columns of pixels in the case of a vertical or horizontal flip, respectively. The flip augmentation is enabled by the boolean horizontal_flip or vertical_flip argument to the ImageDataGenerator constructor. When enabled, each flip is applied at random, to roughly half of the generated images.
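In array terms, a horizontal flip reverses the column order and a vertical flip reverses the row order. A minimal NumPy sketch, assuming (as Keras does) that each enabled flip is decided independently with probability 0.5; `random_flip` is a hypothetical helper:

```python
import numpy as np

def random_flip(image, rng, horizontal=True, vertical=True):
    """Independently flip an (H, W, ...) image left-right and/or upside-down, each with p = 0.5."""
    if horizontal and rng.random() < 0.5:
        image = image[:, ::-1]  # reverse column order: horizontal flip
    if vertical and rng.random() < 0.5:
        image = image[::-1, :]  # reverse row order: vertical flip
    return image

rng = np.random.default_rng(0)
img = np.array([[1, 2], [3, 4]])
for _ in range(3):
    print(random_flip(img, rng).tolist())  # one of the four flip combinations
```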
4. Random Rotation
rotation_range = 90
A rotation augmentation randomly rotates the image by an angle drawn from [-rotation_range, rotation_range] degrees, so it can turn in either direction. The rotation will likely move some pixels out of the frame and leave areas with no pixel data, which are filled according to fill_mode. Here, rotation_range=90 allows rotations of up to 90 degrees either way.
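Rotation is usually implemented by inverse mapping: for each output pixel, compute which input pixel it came from. The sketch below does this with nearest-neighbour sampling and a constant fill for 2-D (grayscale) images. It is a simplification (Keras uses an affine transform with interpolation), and `rotate_nearest` is a hypothetical helper. With rows counted downward, a positive angle here turns the picture clockwise on screen.

```python
import numpy as np

def rotate_nearest(image, degrees, cval=0):
    """Rotate a 2-D image about its centre by `degrees` (nearest-neighbour, constant fill)."""
    h, w = image.shape
    cy, cx = (h - 1) / 2, (w - 1) / 2
    theta = np.deg2rad(degrees)
    rows, cols = np.indices((h, w))
    y, x = rows - cy, cols - cx
    # Inverse rotation: where in the input did each output coordinate come from?
    src_x = x * np.cos(theta) + y * np.sin(theta) + cx
    src_y = -x * np.sin(theta) + y * np.cos(theta) + cy
    src_r = np.rint(src_y).astype(int)
    src_c = np.rint(src_x).astype(int)
    inside = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.full_like(image, cval)  # pixels with no source keep the fill value
    out[inside] = image[src_r[inside], src_c[inside]]
    return out

rng = np.random.default_rng(0)
angle = rng.uniform(-90, 90)  # rotation_range=90: angle drawn from [-90, 90]
img = np.arange(25).reshape(5, 5)
print(round(angle, 1))
print(rotate_nearest(img, angle))
```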
5. Random Brightness
brightness_range = [0.5, 1.0]
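The brightness of an image can be augmented by randomly darkening it, brightening it, or both. brightness_range takes a [min, max] pair, and a brightness factor is drawn from that range for each image: values below 1.0 darken the image and values above 1.0 brighten it. With [0.5, 1.0], images are only ever darkened (down to half brightness) or left unchanged.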
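Conceptually, a brightness change scales every pixel by one random factor and clips the result to the valid range. The sketch below assumes 8-bit pixel values (0 to 255) and approximates what Keras does through PIL's brightness enhancer; `random_brightness` is a hypothetical helper.

```python
import numpy as np

def random_brightness(image, brightness_range, rng):
    """Scale pixel values by a factor drawn from [min, max); assumes 0-255 pixel values."""
    factor = rng.uniform(brightness_range[0], brightness_range[1])
    return np.clip(image.astype("float32") * factor, 0, 255), factor

rng = np.random.default_rng(1)
pixels = np.array([[0, 100, 200, 255]], dtype="uint8")
darker, f = random_brightness(pixels, [0.5, 1.0], rng)    # factor in [0.5, 1.0): darker or same
brighter, g = random_brightness(pixels, [1.5, 1.5], rng)  # fixed factor 1.5
print(brighter)  # 200 * 1.5 and 255 * 1.5 are clipped to 255
```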
6. Random Zoom
zoom_range = [0.5, 1.0]
A zoom augmentation randomly zooms the image in or out. Zooming in interpolates (magnifies) the existing pixels, while zooming out adds new pixels around the image, filled according to fill_mode. Zoom values below 1.0 zoom in: [0.5, 0.5] makes the object in the image twice as large. Values above 1.0 zoom out: [1.5, 1.5] makes the object appear smaller (about two-thirds of its size). A zoom of [1.0, 1.0] has no effect. With zoom_range=[0.5, 1.0], every image is zoomed in by a random amount, and the horizontal and vertical zoom factors are drawn independently, so the aspect ratio can change slightly.
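Why a factor of 0.5 magnifies: each output pixel samples the input at the centre plus its offset times the zoom factor, so a factor of 0.5 only ever reads the central half of the image and stretches it to full size. A nearest-neighbour sketch for 2-D images with a constant fill (a simplification of Keras's affine transform); `zoom_nearest` is a hypothetical helper.

```python
import numpy as np

def zoom_nearest(image, zy, zx, cval=0):
    """Zoom a 2-D image about its centre; factors < 1 zoom in, > 1 zoom out."""
    h, w = image.shape
    cy, cx = (h - 1) / 2, (w - 1) / 2
    rows, cols = np.indices((h, w))
    # Output pixel (r, c) reads the input at centre + offset * factor
    src_r = np.rint(cy + (rows - cy) * zy).astype(int)
    src_c = np.rint(cx + (cols - cx) * zx).astype(int)
    inside = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.full_like(image, cval)  # zooming out leaves a border that needs filling
    out[inside] = image[src_r[inside], src_c[inside]]
    return out

rng = np.random.default_rng(0)
zy, zx = rng.uniform(0.5, 1.0, size=2)  # zoom_range=[0.5, 1.0]: two independent factors
img = np.arange(25).reshape(5, 5)
print(zoom_nearest(img, zy, zx))
```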
7. Channel Shift
channel_shift_range = 150
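Channel shift randomly shifts the channel values by an offset drawn from [-channel_shift_range, channel_shift_range]. Here, channel_shift_range=150 shifts pixel values by up to 150 in either direction, changing the overall colour intensity of the image.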
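A minimal sketch of a channel shift in NumPy. It assumes the behaviour of Keras's random channel shift as I understand it: one random offset is added to every channel, and the result is clipped back to the original image's minimum and maximum values. `random_channel_shift` is a hypothetical helper, not the Keras function.

```python
import numpy as np

def random_channel_shift(image, shift_range, rng):
    """Add one random offset from [-shift_range, shift_range) to all channels of an
    (H, W, C) image, then clip back to the image's original min/max."""
    intensity = rng.uniform(-shift_range, shift_range)
    lo, hi = image.min(), image.max()
    return np.clip(image.astype("float32") + intensity, lo, hi), intensity

rng = np.random.default_rng(3)
img = rng.integers(0, 256, size=(4, 4, 3)).astype("uint8")  # random stand-in image
shifted, intensity = random_channel_shift(img, 150, rng)
print(round(intensity, 1), shifted.min(), shifted.max())
```

Because of the clipping, a large shift pushes many pixels to the image's darkest or brightest original value, which is why strong channel shifts look washed out.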
The final code
# Combining all techniques
# create image data augmentation generator
datagen = ImageDataGenerator(
    width_shift_range=[-225, 225],
    height_shift_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=90,
    zoom_range=[0.5, 1.0],
    brightness_range=[0.5, 1.0],
    channel_shift_range=150,
)
# prepare iterator
it = datagen.flow(samples, batch_size=1)
# generate samples and plot
for i in range(9):
    # define subplot
    plt.subplot(3, 3, i+1)
    # generate batch of images
    batch = next(it)
    # convert to unsigned integers for viewing
    image = batch[0].astype('uint8')
    # plot raw pixel data
    plt.imshow(image)
# show the figure
plt.show()
GitHub Link: Data Augmentation
Thank you for reading the blog. I hope you enjoyed it. Have a great day!