We have all been in situations during model training where the validation accuracy will not increase beyond a certain point, or increases frustratingly slowly. In some cases the validation curve fluctuates up and down. I am not talking about small noise, but wild swings greater than ±5%. When this happens, we start searching in every direction. Is it the learning rate? Is the batch size too big or too small? Is the model failing to generalize?
However, in several cases a bad dataset is the culprit. In this article, I discuss one such case study from debugging abnormal training and validation curve behavior. To set expectations in advance: the goal of this article is not to cover the code for training the model or for writing a custom CNN. Several posts already cover that: Create CNN by Tensorflow.org and Building Custom CNN. You will also need to create your own custom dataset generators, and this blog will be helpful in writing one.
So what was the issue?
Before getting into the issue, let's review the building blocks of our deep learning setup: the CNN architecture, the dataset, data augmentation and model training.
The CNN Architecture
I used a CNN to classify the font style of an image containing text. The design of the CNN was motivated by the DeepFont paper, with several modifications to the network architecture to fit our problem requirements. My original project had 24 font style classes, but for this case study I use three to five classes. A high-level schema of the neural network architecture is shown below:
The original dataset had 24 classes (font styles). Each class had ~5,000 training samples and ~1,000 validation samples. Data for each class is synthesized by writing fake text, in that class's font style, on a 100×100 white or noisy background patch. The noisy backgrounds are generated from cropped blank areas of scanned documents, and additional augmentation is applied to these backgrounds. We will discuss augmentation in more detail in the next section. The figures below show examples of training images generated using open-source Google Fonts. From left to right: Arimo-Regular, Courier-Prime-Regular, Lato-Light and Tinos-Italic.
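The background step can be sketched roughly as follows. This is a minimal illustration, not the author's generator. It assumes a scanned page is available as a 2-D list of grayscale values (0 = black, 255 = white), and it treats a random 100×100 crop as "blank" when its mean intensity is high enough. The function name `crop_blank_patch` and the `min_mean` threshold are hypothetical, and rendering the text on top of the patch is left out.

```python
import random


def crop_blank_patch(page, size=100, min_mean=200.0, max_tries=1000, rng=None):
    """Randomly crop a size x size patch that looks like blank paper.

    page: 2-D list of grayscale values (0 = black, 255 = white) from a scan.
    A crop is accepted when its mean intensity is at least `min_mean`,
    i.e. it is mostly background rather than printed text.
    Returns None if no acceptable patch is found within `max_tries`.
    """
    rng = rng or random.Random()
    h, w = len(page), len(page[0])
    if h < size or w < size:
        return None  # page too small to hold a patch of this size
    for _ in range(max_tries):
        top = rng.randint(0, h - size)
        left = rng.randint(0, w - size)
        patch = [row[left:left + size] for row in page[top:top + size]]
        mean = sum(map(sum, patch)) / (size * size)
        if mean >= min_mean:
            return patch
    return None
```

Passing a seeded `random.Random` makes the crops reproducible, which helps when you later need to regenerate a suspicious sample.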
Standard data augmentation methods were applied on top of the synthetically generated data described in the previous section. The following augmentations were used:
- Affine Rotation
- Variation in Text Color
- Variation in Font Size
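Augmentation turns out to matter a great deal in this case study, so it helps to sample each image's augmentation parameters explicitly and keep a record of them. The sketch below is a hypothetical illustration, not the author's pipeline. It assumes the three augmentations above are controlled by a rotation angle, a text gray level and a font size. The ranges, `AugParams` and `build_manifest` are made up for the example, and rendering the image itself is left out.

```python
import random
from dataclasses import dataclass, asdict

# Hypothetical ranges for illustration; the article does not state the real ones.
ROTATION_RANGE_DEG = (-5.0, 5.0)   # affine rotation angle
TEXT_GRAY_RANGE = (0, 255)         # text pixel intensity, 0 = black
FONT_SIZE_RANGE_PX = (12, 48)      # rendered font size


@dataclass(frozen=True)
class AugParams:
    rotation_deg: float
    text_gray: int
    font_size_px: int


def sample_aug_params(rng: random.Random) -> AugParams:
    """Draw one set of augmentation parameters for a single synthetic sample."""
    return AugParams(
        rotation_deg=rng.uniform(*ROTATION_RANGE_DEG),
        text_gray=rng.randint(*TEXT_GRAY_RANGE),
        font_size_px=rng.randint(*FONT_SIZE_RANGE_PX),
    )


def build_manifest(font_styles, samples_per_class, seed=0):
    """Record the augmentation parameters used for every generated sample.

    Keeping this manifest makes it possible to trace a misclassified image
    back to the exact augmentation that produced it.
    """
    rng = random.Random(seed)  # seeded so the dataset can be regenerated exactly
    manifest = []
    for label in font_styles:
        for i in range(samples_per_class):
            record = {"label": label, "index": i}
            record.update(asdict(sample_aug_params(rng)))
            manifest.append(record)
    return manifest
```

With a manifest like this, a misclassified validation image can be traced back to the rotation, gray level and font size that produced it.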
Training was conducted for 200 epochs with a batch size of 64. I chose a conservative learning rate of 1e-05, and no early stopping was used initially. I used the Stochastic Gradient Descent (SGD) optimizer with a momentum of 0.9 and a decay rate of 1e-6. Since this is a classification problem with one-hot encoded labels, categorical cross-entropy loss was chosen.
When training started, the loss decreased for both training and validation. Then, suddenly, at around the 50th epoch (50-55% validation accuracy), the validation loss started to increase while the training loss continued to decrease.
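Here is a small pure-Python sketch of the optimizer and loss pieces involved, to make the settings concrete. It assumes the decay of 1e-6 is inverse-time decay applied once per update, lr_t = lr_0 / (1 + decay · t), which is the behavior of the `decay` argument in older Keras `SGD` optimizers. The function names are hypothetical. The step counts assume the original 24-class setup with ~5,000 training images per class.

```python
import math


def time_based_lr(initial_lr: float, decay: float, iteration: int) -> float:
    """Learning rate after `iteration` updates under inverse-time decay."""
    return initial_lr / (1.0 + decay * iteration)


def sgd_momentum_step(weights, grads, velocity, lr, momentum=0.9):
    """One SGD-with-momentum update: v <- m*v - lr*g, then w <- w + v."""
    new_velocity = [momentum * v - lr * g for v, g in zip(velocity, grads)]
    new_weights = [w + v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


def categorical_cross_entropy(one_hot, probs, eps=1e-7):
    """Cross-entropy between a one-hot label and predicted class probabilities.

    Probabilities are clipped to [eps, 1 - eps] to avoid log(0).
    """
    return -sum(
        y * math.log(min(max(p, eps), 1.0 - eps)) for y, p in zip(one_hot, probs)
    )


# Worked numbers for the original setup (24 classes x ~5000 training images).
steps_per_epoch = math.ceil(24 * 5000 / 64)          # 1875 updates per epoch
total_steps = 200 * steps_per_epoch                  # 375000 updates in total
final_lr = time_based_lr(1e-5, 1e-6, total_steps)    # 1e-5 / 1.375
```

Under that assumption, the learning rate only drops from 1e-5 to roughly 7.3e-6 over the full 200 epochs. The decay term therefore changes the effective step size by only about a quarter.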
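The pattern above can be detected automatically: validation loss bottoms out and then rises while training loss keeps falling. This is also the idea behind early stopping. The sketch below is minimal and follows the spirit of patience-based early stopping (for example Keras's `EarlyStopping` callback). `find_divergence_epoch` is a hypothetical helper that works on per-epoch loss lists.

```python
def find_divergence_epoch(train_loss, val_loss, patience=5, min_delta=0.0):
    """Early-stopping style check for the overfitting-like pattern.

    Returns the 1-based epoch with the best validation loss if, for `patience`
    epochs afterwards, validation loss never improves on that best value while
    training loss is lower than it was at the best epoch. Returns None otherwise.
    """
    best_epoch, best_val, wait = 0, float("inf"), 0
    for epoch, loss in enumerate(val_loss):
        if loss < best_val - min_delta:
            # New best validation loss: reset the patience counter.
            best_epoch, best_val, wait = epoch, loss, 0
        else:
            wait += 1
            if wait >= patience:
                # Validation stalled or rose; did training loss keep falling?
                if train_loss[epoch] < train_loss[best_epoch]:
                    return best_epoch + 1
                return None
    return None
```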
What was the issue?
When we see this behavior, the first reaction is "the model is overfitting", i.e., it is unable to generalize well. This makes sense: the model will learn whatever you give it, so the training loss decreases while performance on the validation set does not improve.
What is the solution?
Once we conclude the issue is overfitting, it's common to add data augmentation to the training set to make it more representative of real-world data. We look into regularization techniques such as increasing the dropout rate or adding dropout layers, tinker with the batch size, or lower the learning rate. We wonder if our model is too complex and consider reducing the number of weights in its layers. While these techniques can be useful, some of them lead us down a rabbit hole and take significant time and effort. I tried some of these methods without success, so the most logical next step was to look at the data.
Since this data is synthetically generated for each class by our own code, it cannot be mislabeled (right?). We'll get to that later.
Looking at all of the data is tedious. So I took the poorly trained model above, ran inference on the validation dataset and analyzed the misclassified samples. If you want to dig deeper into the data, you can set a confidence threshold on the top predicted class and further analyze the images that fall below it.
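A minimal sketch of that triage step follows. It assumes you already have per-sample softmax probabilities from the model on the validation set. `triage_predictions` and the 0.6 threshold are illustrative choices, not values from the article.

```python
def triage_predictions(probs, labels, class_names, threshold=0.6):
    """Split validation samples into misclassified and low-confidence lists.

    probs:  per-sample lists of softmax probabilities
    labels: integer ground-truth class indices
    Returns (misclassified, low_confidence), where misclassified holds
    (index, true_name, predicted_name, confidence) tuples and low_confidence
    holds (index, true_name, confidence) for correct but unsure predictions.
    """
    misclassified, low_confidence = [], []
    for idx, (p, y) in enumerate(zip(probs, labels)):
        pred = max(range(len(p)), key=p.__getitem__)  # argmax
        conf = p[pred]
        if pred != y:
            misclassified.append((idx, class_names[y], class_names[pred], conf))
        elif conf < threshold:
            low_confidence.append((idx, class_names[y], conf))
    # Most confident mistakes first: these are the most suspicious samples.
    misclassified.sort(key=lambda r: r[3], reverse=True)
    low_confidence.sort(key=lambda r: r[2])
    return misclassified, low_confidence
```

Sorting mistakes by confidence puts the most suspicious cases first. An image the model gets wrong with high confidence is a good candidate for a labeling or augmentation problem.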
We first check whether there are bad images that are not representative of real-world data. One example is shared below: excessive augmentation made the font size too small and left the image blurry and full of noise. If your dataset is huge, scanning for such examples manually is like finding a needle in a haystack and isn't practical.
Looking for garbage data manually is only a suggestion when the dataset is a reasonable size. Deciding whether an image is garbage is also subjective. For the example below, ask yourself whether you will encounter images like this in the real world. If the answer is "yes", it is better to keep these images in the dataset. Otherwise, tune down your data augmentation to avoid generating such images.
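For large datasets, a cheap automatic filter can shortlist candidates for manual review. One common heuristic is the variance of the Laplacian. It is swapped in here and was not used in the article. Blurry or nearly empty patches have few edges and therefore a low score. The sketch assumes grayscale images stored as 2-D lists, and the threshold has to be tuned on your own data.

```python
def laplacian_variance(img):
    """Variance of the 4-neighbour Laplacian over interior pixels.

    img is a 2-D list of grayscale values. Low values indicate little edge
    content (blurry or nearly empty patches).
    """
    h, w = len(img), len(img[0])
    responses = [
        img[r - 1][c] + img[r + 1][c] + img[r][c - 1] + img[r][c + 1] - 4 * img[r][c]
        for r in range(1, h - 1)
        for c in range(1, w - 1)
    ]
    mean = sum(responses) / len(responses)
    return sum((x - mean) ** 2 for x in responses) / len(responses)


def flag_suspect_images(images, threshold):
    """Return indices of images whose sharpness score falls below threshold."""
    return [i for i, img in enumerate(images) if laplacian_variance(img) < threshold]
```

A low score only means "few edges". Flagged images still need a human look, since a legitimately sparse image can score low too.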
While looking at the data, I noticed that some samples from different classes looked very similar. My task is to classify font styles, and some font styles look very similar even to humans. A CNN can be trained to tell these apart. But let's not forget that, on top of the font styles being similar, we are adding augmentation, which can make things worse. It is a good idea to verify whether this is the cause. So our leading hypothesis so far is that data from two or more classes is very similar, or even identical, for unknown reasons.
The next step: can you recreate the problem?
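One way to check this hypothesis with numbers rather than by eye is to count which pairs of classes the model confuses most often in its validation predictions. These counts are the off-diagonal cells of a confusion matrix. A minimal sketch follows. `most_confused_pairs` is a hypothetical helper, and labels and predictions are assumed to be integer class indices.

```python
from collections import Counter


def most_confused_pairs(labels, preds, class_names, top_k=3):
    """Count off-diagonal confusion-matrix cells and return the largest ones.

    Treats (a -> b) and (b -> a) as one pair, since the hypothesis is that
    the two classes look alike in either direction.
    """
    counts = Counter()
    for y, p in zip(labels, preds):
        if y != p:
            counts[tuple(sorted((y, p)))] += 1
    return [
        (class_names[a], class_names[b], n) for (a, b), n in counts.most_common(top_k)
    ]
```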
Recreating the Problem
I decided to experiment with a small number of classes rather than the 24 classes I originally used. I started with classes that are not similar: Calibri Light and the Google Fonts Tinos-Regular and Gelasio. As you can see below, these classes are visually far apart. I removed all data augmentation except variation in text color (#5) and variation in font size (#6).
Given our hypothesis, I expected the network to converge on these dissimilar classes. When training with the original set of hyperparameters, the results matched this expectation.
Now I will try to prove my initial hypothesis, i.e., that some classes are too similar. To test it, I added other Calibri variants that I expected to be very similar: Calibri Bold Italic and Calibri Regular.
So now our set of classes is:
Calibri Light, Calibri Regular, Calibri Bold Italic, Tinos-Regular and Gelasio. For data augmentation, I again kept only #5 (text color) and #6 (font size), as in the previous experiment, and with this setup I can recreate the issue.
So are we implying that the model is not capable of distinguishing between Calibri Regular and Calibri Light? At this stage, you may be tempted to look into model complexity or to train the model with a different set of hyperparameters.
Before we go down this rabbit hole, let's look at our data augmentation. In these experiments, we used the augmentation techniques "Variation in Text Color" (#5) and "Variation in Font Size" (#6). For text color augmentation, I was simply varying the intensity of the text pixels (0-255). I may have been creating augmented samples of Calibri Regular that resemble Calibri Light, or augmented samples of Calibri Light that resemble Calibri Regular.
Let's test this hypothesis by removing color augmentation while keeping the same five classes and augmentation #6. The model training hyperparameters are unchanged. Below is the result of training for 100 epochs. After removing color augmentation, training and validation accuracy start to converge, and the loss curves converge nicely as well.
A second look at our dataset saved us from spending a lot of time on other lines of investigation. I realized I was polluting my own dataset! This case study also shows that we should choose our data augmentation methods wisely. They should be specific to your use case, and applying standard image data augmentation methods without thorough reasoning should be strictly avoided.
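A back-of-the-envelope calculation shows why free text-intensity variation can blur the difference between font weights. This is a toy model, not a measurement. It reduces a glyph to a single stroke cross-section on a white background (255) and ignores anti-aliasing and glyph shape. It also assumes made-up stroke widths of 1 px for the light weight and 2 px for the regular weight.

```python
def stroke_darkness(width_px, text_gray, background=255):
    """Total 'ink' across a stroke cross-section on a uniform background."""
    return width_px * (background - text_gray)


def darkness_range(width_px, gray_range, background=255):
    """(min, max) ink for a stroke when text gray is sampled from gray_range."""
    lo_gray, hi_gray = gray_range
    # Lighter text (higher gray value) leaves less ink on the page.
    return (
        stroke_darkness(width_px, hi_gray, background),
        stroke_darkness(width_px, lo_gray, background),
    )


def ranges_overlap(a, b):
    """True if two closed intervals (lo, hi) share at least one value."""
    return a[0] <= b[1] and b[0] <= a[1]


def matching_gray(light_width_px, regular_width_px, light_gray=0, background=255):
    """Gray level at which the wider stroke carries the same ink as the thin one."""
    return background - (background - light_gray) * light_width_px / regular_width_px


# Hypothetical widths: light weight = 1 px, regular weight = 2 px.
light_black = stroke_darkness(1, 0)      # 255
regular_gray = stroke_darkness(2, 128)   # 254, almost the same ink
full_range_overlap = ranges_overlap(darkness_range(1, (0, 255)), darkness_range(2, (0, 255)))
dark_only_overlap = ranges_overlap(darkness_range(1, (0, 60)), darkness_range(2, (0, 60)))
```

Under these assumptions, a 2 px regular stroke drawn at gray 128 carries almost exactly the ink of a 1 px light stroke drawn in black. With gray levels sampled from the full 0-255 range, the two weights' darkness ranges overlap. Restricting text to dark grays (0-60 in this toy) separates them. Whether such a restriction is acceptable depends on which text colors occur in your real data.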
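The experiment above is an ablation: change one thing, keep the hyperparameters fixed, and compare. A leave-one-out grid makes this systematic when there are several augmentations. A minimal sketch follows. `run_ablation` is a hypothetical helper, and `train_fn` stands in for your own training code: a callable that trains a model with the given augmentations and returns a validation metric.

```python
from typing import Callable, Dict, Iterator, List, Optional, Tuple


def leave_one_out_configs(
    augmentations: List[str],
) -> Iterator[Tuple[Optional[str], List[str]]]:
    """Yield (removed, enabled) pairs: a baseline, then one run per dropped augmentation."""
    yield None, list(augmentations)  # baseline with everything enabled
    for name in augmentations:
        yield name, [a for a in augmentations if a != name]


def run_ablation(
    train_fn: Callable[[List[str]], float], augmentations: List[str]
) -> Dict[Optional[str], float]:
    """Train once per config and map the removed augmentation to its metric."""
    return {
        removed: train_fn(enabled)
        for removed, enabled in leave_one_out_configs(augmentations)
    }
```

A large jump in validation accuracy when one augmentation is removed, as with text color here, points straight at that augmentation.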
Ankit Goyal is a Machine Learning Engineer at Informed IQ. He focuses on using deep learning and Transformers to automate the lending process and is passionate about applying ML algorithms in computer vision.