Simple Image Classification With CNN Using Tensorflow For Beginners


Learn to perform a simple image classification task by doing a project that will use a Convolutional Neural Network.

Source: Unsplash by Tran Mau Tri Tam

Image classification is not a hard topic anymore. TensorFlow has inbuilt functionalities that take care of the complex mathematics for us, so we can use a neural network without knowing all of its internal details. In today's project, I used a Convolutional Neural Network (CNN), which is an advanced version of the neural network. It condenses a picture down to its important features. If you have worked with the FashionMNIST dataset, which contains shirts, shoes, handbags, and so on, you know a CNN will figure out the important portions of the images. For example, if you see a shoelace, it might be a shoe; if there are a collar and buttons, it might be a shirt; and if there is a handle, it might be a handbag.

Overview

The simple CNN we will build today to classify a set of images will consist of convolutions and pooling. Inputs get modified in the convolution layers. You can stack one or more convolutions depending on your requirements. Inputs go through several filters, and those filters slide across the inputs to learn portions of an input, such as the buttons of a shirt, the handle of a handbag, or the lace of a shoe. I am not going any deeper into the theory today, because this article is for beginners.

Pooling is another very important part of a CNN. Pooling layers work on local regions like convolutions do, but they have no filters; each one is a vector-to-scalar transformation. Max pooling keeps the pixel with the highest intensity in each region and discards the rest, while average pooling simply computes the average of the region. A 2 x 2 pooling will reduce the size of the feature maps by a factor of 2. Even if you don't know the mathematical part of it, you can still solve a deep learning problem, and I will explain every line of code. Nowadays we have such rich libraries that we can do all this amazing work without knowing much math or coding. Let's dive in.
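To make that size reduction concrete, here is a minimal sketch (assuming TensorFlow 2.x; the layer sizes are just illustrative) that pushes one dummy 150 x 150 RGB image through a single convolution and a 2 x 2 max-pooling layer and prints the resulting shapes:

import tensorflow as tf

# A batch containing one dummy 150 x 150 RGB image
x = tf.random.normal((1, 150, 150, 3))

conv = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')
pool = tf.keras.layers.MaxPooling2D(2, 2)

features = conv(x)       # 3 x 3 filters trim a 1-pixel border on each side
pooled = pool(features)  # 2 x 2 pooling halves height and width

print(features.shape)    # (1, 148, 148, 16)
print(pooled.shape)      # (1, 74, 74, 16)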

CNN Development

I used a Google Colab notebook, so you can follow along even if you don't have Anaconda and Jupyter Notebook installed. Google's Colaboratory notebooks are available to everyone, and there are lots of YouTube videos on how to use Google Colab; please feel free to check those out if Colab is new to you. We will use a dataset that contains images of cats and dogs, originally published on Kaggle. Our goal is to develop a convolutional neural network that will successfully classify cats and dogs from a picture.

First import all the required packages and libraries.

import os
import zipfile
import random
import shutil
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator

It's time to get our dataset. We will use the command-line tool wget to bring the dataset into the notebook. Just a reminder: once your Google Colab session is over, you have to download the dataset again. Let's download the full Cats-v-Dogs dataset, store it as cats-and-dogs.zip, and save it in the directory /tmp.

!wget --no-check-certificate \
    "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip" \
    -O "/tmp/cats-and-dogs.zip"
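If wget is not available in your environment, the same download can be done in pure Python. Here is a sketch using only the standard library, assuming the Microsoft URL above is still reachable and its certificate validates:

import urllib.request

url = ("https://download.microsoft.com/download/3/E/1/"
       "3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
# Download the archive to the same path the wget command used
urllib.request.urlretrieve(url, "/tmp/cats-and-dogs.zip")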

Now extract the data from the zip file, which will generate a directory named /tmp/PetImages with two subdirectories called Cat and Dog. That's how the data was originally structured.

local_zip = '/tmp/cats-and-dogs.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp')
zip_ref.close()
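As a side note, a with statement closes the archive automatically even if extraction raises an error; this equivalent sketch does the same job:

with zipfile.ZipFile('/tmp/cats-and-dogs.zip', 'r') as zip_ref:
    zip_ref.extractall('/tmp')  # the archive is closed on exit either way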

Let's check the Cat and Dog folders.

print(len(os.listdir('/tmp/PetImages/Cat/')))
print(len(os.listdir('/tmp/PetImages/Dog/')))

As the data is available to use, we now need to create a directory named cats-v-dogs with training and testing subdirectories, and, inside each of those, cats and dogs subdirectories that the images will be copied into later.

try:
    os.mkdir('/tmp/cats-v-dogs/')
    os.mkdir('/tmp/cats-v-dogs/training/')
    os.mkdir('/tmp/cats-v-dogs/testing/')
    os.mkdir('/tmp/cats-v-dogs/training/cats/')
    os.mkdir('/tmp/cats-v-dogs/training/dogs/')
    os.mkdir('/tmp/cats-v-dogs/testing/cats/')
    os.mkdir('/tmp/cats-v-dogs/testing/dogs/')
except OSError:
    pass
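A more compact alternative (assuming Python 3.2+ for the exist_ok flag) builds each leaf directory and its parents in one call, so rerunning the cell does not fail:

for d in ['/tmp/cats-v-dogs/training/cats/',
          '/tmp/cats-v-dogs/training/dogs/',
          '/tmp/cats-v-dogs/testing/cats/',
          '/tmp/cats-v-dogs/testing/dogs/']:
    os.makedirs(d, exist_ok=True)  # also creates /tmp/cats-v-dogs/training etc.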

Now split the data for training and testing and place it in the correct directories with a function called split_data. It takes a SOURCE directory containing the files, a TRAINING directory where one slice of the data will be copied to, a TESTING directory where the remaining data will be copied to, and a SPLIT_SIZE that determines where to slice the data.

def split_data(SOURCE, TRAINING, TESTING, SPLIT_SIZE):
    cont = os.listdir(SOURCE)
    lenList = len(cont)
    # Shuffle the full file list, then slice it at SPLIT_SIZE
    shuffleList = random.sample(cont, lenList)
    slicePoint = round(lenList * SPLIT_SIZE)
    # Copy the first slice into the TRAINING directory, skipping empty files
    for i in range(slicePoint):
        if os.path.getsize(os.path.join(SOURCE, shuffleList[i])) != 0:
            shutil.copy(os.path.join(SOURCE, shuffleList[i]), TRAINING)

The code block below, still inside split_data, checks the remaining files for size and puts them in the TESTING directory.

    # Copy the rest of the shuffled list into the TESTING directory
    for j in range(slicePoint, len(shuffleList)):
        if os.path.getsize(os.path.join(SOURCE, shuffleList[j])) != 0:
            shutil.copy(os.path.join(SOURCE, shuffleList[j]), TESTING)

The function is ready. Use split_data to split the data in each source directory and copy it over to the training and testing directories.

CAT_SOURCE_DIR = "/tmp/PetImages/Cat/"
TRAINING_CATS_DIR = "/tmp/cats-v-dogs/training/cats/"
TESTING_CATS_DIR = "/tmp/cats-v-dogs/testing/cats/"
DOG_SOURCE_DIR = "/tmp/PetImages/Dog/"
TRAINING_DOGS_DIR = "/tmp/cats-v-dogs/training/dogs/"
TESTING_DOGS_DIR = "/tmp/cats-v-dogs/testing/dogs/"
split_size = .9
split_data(CAT_SOURCE_DIR, TRAINING_CATS_DIR, TESTING_CATS_DIR, split_size)
split_data(DOG_SOURCE_DIR, TRAINING_DOGS_DIR, TESTING_DOGS_DIR, split_size)

Check the length of the training and testing directories.

print(len(os.listdir('/tmp/cats-v-dogs/training/cats/')))
print(len(os.listdir('/tmp/cats-v-dogs/training/dogs/')))
print(len(os.listdir('/tmp/cats-v-dogs/testing/cats/')))
print(len(os.listdir('/tmp/cats-v-dogs/testing/dogs/')))

Data preprocessing is done. Here comes the fun part: we will develop a Keras model to classify the cats and dogs. In this model, we will use three convolutional layers, each followed by a pooling layer; you can try it with fewer or more convolution layers. Each layer uses an activation function, and the first one sets input_shape to 150 x 150 x 3, so all the images will be reshaped into the same square shape; otherwise, real-world images come in different sizes and shapes. In the first layer, the filter size is 3 x 3 and the number of filters is 16. Max pooling with a 2 x 2 window condenses the feature maps by a factor of 2. The two further layers use different numbers of filters. You can add extra Conv2D and MaxPooling2D layers to observe the results.

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer=RMSprop(lr=0.001), loss='binary_crossentropy', metrics=['acc'])
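To see how each convolution and pooling layer shrinks the feature maps before they reach the Dense layers, you can print the architecture:

model.summary()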

In the compile function, we should pass at least the optimizer and loss parameters. Here the learning rate is 0.001. It is important to choose a reasonable learning rate: one that is too small or too big can make training inefficient. The next step is to normalize the data.

base_dir = '/tmp/cats-v-dogs'
TRAINING_DIR = os.path.join(base_dir, 'training')
train_datagen = ImageDataGenerator(rescale=1.0/255)
train_generator = train_datagen.flow_from_directory(TRAINING_DIR,
                                                    batch_size=20,
                                                    class_mode='binary',
                                                    target_size=(150, 150))

ImageDataGenerator helps normalize the pixel values so that they fall between 0 and 1; originally they can range from 0 to 255, as you may already know. Then we pass our data in batches for training; here the batch_size is 20. We need to normalize the testing data in the same way:

VALIDATION_DIR = os.path.join(base_dir, 'testing')
validation_datagen = ImageDataGenerator(rescale=1.0/255)
validation_generator = validation_datagen.flow_from_directory(VALIDATION_DIR,
                                                              batch_size=20,
                                                              class_mode='binary',
                                                              target_size=(150, 150))

Now train the model. Let's train it for 15 epochs; please feel free to test with more or fewer. You should keep track of four metrics: loss, accuracy, validation loss, and validation accuracy. The loss should go down and the accuracy should go up with every epoch.

history = model.fit_generator(train_generator,
                              epochs=15,
                              verbose=1,
                              validation_data=validation_generator)

I got 89.51% accuracy on the training set and 91.76% accuracy on the validation data. I have to mention one thing here: if accuracy on the training set is very high while accuracy on the test or validation set is not that good, that is an overfitting problem. It means the model learned the training dataset so well that it knows only that data and does poorly on unseen data. That's not our goal; our goal is to develop a model that generalizes well to most of the data out there. When you see overfitting, you need to modify the training parameters, perhaps with fewer epochs or a different learning rate. We will talk about how to deal with overfitting in a later article.
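One common way to spot overfitting is to plot training accuracy against validation accuracy per epoch. Here is a minimal sketch, assuming matplotlib is installed and using the 'acc'/'val_acc' keys produced by the metrics=['acc'] setting above:

import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, label='Training accuracy')
plt.plot(epochs, val_acc, label='Validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

A widening gap between the two curves, with training accuracy climbing while validation accuracy stalls, is the classic sign of overfitting.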

