In this post, we will explore some of the more recent studies on imbalanced datasets.
Imbalanced Data Problem:
A common problem in machine learning is a model whose accuracy looks extremely good, but only because one class vastly outnumbers the other in the dataset. The model learns to exploit this: by always predicting the abundant class, it keeps the loss small.
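A tiny illustration with made-up numbers shows how misleading accuracy can be. On a hypothetical 99:1 split, a "classifier" that always predicts the abundant class reaches 99% accuracy without ever detecting a minority example.

```python
# Hypothetical label vector: 990 examples of class 0, 10 of class 1 (a 99:1 imbalance).
labels = [0] * 990 + [1] * 10

# A "model" that ignores its input and always predicts the abundant class.
predictions = [0] * len(labels)

accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
# Recall on the minority class: the fraction of true class-1 examples predicted as 1.
minority_recall = sum(p == 1 for p, y in zip(predictions, labels) if y == 1) / labels.count(1)

print(f"accuracy = {accuracy:.2%}")                # accuracy = 99.00%
print(f"minority recall = {minority_recall:.2%}")  # minority recall = 0.00%
```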
There are two representative families of methods for handling this problem.
- Sampling Methods: These aim to balance the class priors by resampling, e.g. oversampling, undersampling, informed undersampling, synthetic sampling with data generation, and sampling with data cleaning. Oversampling carries a high risk of overfitting to the minority class, while undersampling has the disadvantage of not using all of the data you have.
- Cost-Sensitive Methods: Instead of modifying the data, these methods take the cost of misclassification into account, e.g. by modifying the learning rate for minority-class examples.
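Here is a minimal NumPy sketch of the simplest member of each family, assuming integer class labels. It covers random undersampling and random oversampling (plain duplication, not synthetic generation such as SMOTE). It also includes a cost-sensitive loss that scales each example by an inverse-frequency class weight, which stands in for the per-example learning-rate idea above. The weighting formula is the "balanced" heuristic that scikit-learn uses for `class_weight="balanced"`. All function names and the toy data are hypothetical.

```python
import numpy as np

def random_undersample(X, y, rng):
    """Randomly drop examples so every class shrinks to the size of the smallest one."""
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()
    keep = np.concatenate(
        [rng.choice(np.flatnonzero(y == c), size=n_min, replace=False) for c in classes]
    )
    return X[keep], y[keep]

def random_oversample(X, y, rng):
    """Keep every example and add random duplicates until every class matches the largest one."""
    classes, counts = np.unique(y, return_counts=True)
    n_max = counts.max()
    parts = []
    for c, n_c in zip(classes, counts):
        idx = np.flatnonzero(y == c)
        extra = rng.choice(idx, size=n_max - n_c, replace=True)  # duplicated examples
        parts.append(np.concatenate([idx, extra]))
    keep = np.concatenate(parts)
    return X[keep], y[keep]

def inverse_frequency_weights(y):
    """Cost-sensitive alternative: weight each class by n_samples / (n_classes * n_c)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

def weighted_cross_entropy(probs, y, class_weight):
    """Weighted mean negative log-likelihood; probs is (n, C), y holds integer labels."""
    w = np.array([class_weight[int(c)] for c in y])
    nll = -np.log(probs[np.arange(len(y)), y])
    return float(np.sum(w * nll) / np.sum(w))

# Toy data (hypothetical): 8 examples of class 0, 2 of class 1.
rng = np.random.default_rng(0)
X = np.arange(20, dtype=float).reshape(10, 2)
y = np.array([0] * 8 + [1] * 2)

Xu, yu = random_undersample(X, y, rng)
Xo, yo = random_oversample(X, y, rng)
weights = inverse_frequency_weights(y)
print(np.bincount(yu), np.bincount(yo))  # [2 2] [8 8]
print(weights)                           # {0: 0.625, 1: 2.5}
```

Undersampling throws away 6 of the 8 majority examples. Oversampling repeats the 2 minority examples several times. Together they illustrate the trade-off described above.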
Let’s have a look at Learning Deep Representation for Imbalanced Classification by Huang et al. (2016). The authors propose a method that forces a deep network to maintain both inter-cluster and inter-class margins, which should lead to a more discriminative representation being learned from imbalanced datasets. They hypothesize that this can be achieved with quintuplet sampling combined with a triple-header hinge loss.
Let’s first take a quick, high-level look at how this works.
I will break the training process down into 6 steps and explain some of the details along the way.
Step 0: Run k-means clustering on the feature embeddings produced by the ConvNet being trained. At the very first step the ConvNet is not trained at all, so its features would be meaningless. In practice, the authors therefore suggest using a pre-trained network to obtain the initial features.
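Here is a minimal sketch of step 0. It assumes the embeddings have already been extracted into a NumPy array, for example from a pre-trained network as the authors suggest for the first round. It runs a plain Lloyd's k-means separately inside each class. Clustering per class is my assumption, chosen because the quintuplet definitions treat every cluster as belonging to a single class. The function names, the number of clusters per class, and the toy data are hypothetical.

```python
import numpy as np

def kmeans(points, k, rng, n_iters=100):
    """Plain Lloyd's k-means. Returns (centroids, assignment of each point)."""
    k = min(k, len(points))
    centroids = points[rng.choice(len(points), size=k, replace=False)]

    def assign(cents):
        # Squared Euclidean distance from every point to every centroid.
        d = ((points[:, None, :] - cents[None, :, :]) ** 2).sum(axis=-1)
        return d.argmin(axis=1)

    for _ in range(n_iters):
        labels = assign(centroids)
        # Move each centroid to the mean of its points; keep it in place if it lost all points.
        new = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new, centroids):
            break
        centroids = new
    return centroids, assign(centroids)

def cluster_per_class(embeddings, labels, clusters_per_class, seed=0):
    """Run k-means inside each class separately.

    Returns one globally unique cluster id per sample, the stacked centroids,
    and the class that each centroid belongs to.
    """
    rng = np.random.default_rng(seed)
    cluster_ids = np.empty(len(labels), dtype=int)
    centroids, centroid_class = [], []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        cents, assign = kmeans(embeddings[idx], clusters_per_class, rng)
        cluster_ids[idx] = assign + len(centroids)  # offset keeps ids unique across classes
        centroids.extend(cents)
        centroid_class.extend([int(c)] * len(cents))
    return cluster_ids, np.array(centroids), np.array(centroid_class)

# Hypothetical 2-D embeddings: class 0 forms two blobs, class 1 (the minority) forms one.
rng = np.random.default_rng(1)
emb = np.vstack([
    rng.normal([0.0, 0.0], 0.1, size=(20, 2)),
    rng.normal([5.0, 5.0], 0.1, size=(20, 2)),
    rng.normal([0.0, 5.0], 0.1, size=(5, 2)),
])
lab = np.array([0] * 40 + [1] * 5)
ids, cents, cent_cls = cluster_per_class(emb, lab, clusters_per_class=2)
```

Because the clusters never mix classes, "within-cluster" automatically implies "within-class", which is what the quintuplet definitions rely on.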
Step 1: Generate and fill a “Quintuplet Table” using the cluster and class labels: for each sample in the dataset, we randomly sample a quintuplet.
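A quintuplet $(x_i, x_i^{p+}, x_i^{p-}, x_i^{p--}, x_i^{n})$ is chosen by the following rule, illustrated in the figure:
$$D(f(x_i), f(x_i^{p+})) < D(f(x_i), f(x_i^{p-})) < D(f(x_i), f(x_i^{p--})) < D(f(x_i), f(x_i^{n}))$$
where $f(\cdot)$ is the embedding function, usually the ConvNet being trained, and $D(\cdot, \cdot)$ is a distance metric of our choice.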
The constraint’s variables are:
- $x_i^{p+}$: the most distant within-cluster neighbor of the anchor $x_i$.
- $x_i^{p-}$: the nearest within-class neighbor of the anchor from a different cluster.
- $x_i^{p--}$: the most distant within-class neighbor of the anchor.
- $x_i^{n}$: the nearest between-class neighbor of the anchor.
Hence, for each sample (the anchor) in the dataset, you sample the other four members so that they satisfy the constraint above. The constraint therefore acts within a class, across its clusters, as well as between classes.
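The quintuplet rule can be sketched as a brute-force search over pairwise distances. This sketch assumes squared Euclidean distance for $D$. It takes exactly the extremal neighbor named in each definition, which is a deterministic simplification of the random sampling described above. It also skips anchors whose cluster or class cannot supply all four partners. Names and toy data are hypothetical. The pairwise distance matrix is quadratic in memory, so this only suits small examples.

```python
import numpy as np

def build_quintuplet_table(emb, labels, clusters):
    """Map each anchor index i to (i, p_plus, p_minus, p_dminus, n).

    Anchors whose cluster or class is too small to supply all four partners are skipped.
    """
    n_samples = len(emb)
    # D(f(x_i), f(x_j)): squared Euclidean distance between all pairs of embeddings.
    dist = ((emb[:, None, :] - emb[None, :, :]) ** 2).sum(axis=-1)
    not_self = ~np.eye(n_samples, dtype=bool)

    def pick(row, mask, farthest):
        # Restrict the search to `mask`, then take the farthest or nearest candidate.
        if farthest:
            return int(np.where(mask, row, -np.inf).argmax())
        return int(np.where(mask, row, np.inf).argmin())

    table = {}
    for i in range(n_samples):
        same_class = labels == labels[i]
        same_cluster = clusters == clusters[i]
        masks = {
            "p_plus": same_cluster & not_self[i],   # within-cluster, excluding the anchor
            "p_minus": same_class & ~same_cluster,  # same class, different cluster
            "p_dminus": same_class & not_self[i],   # same class, any cluster
            "n": ~same_class,                       # any other class
        }
        if not all(m.any() for m in masks.values()):
            continue
        row = dist[i]
        table[i] = (
            i,
            pick(row, masks["p_plus"], farthest=True),    # most distant within-cluster neighbor
            pick(row, masks["p_minus"], farthest=False),  # nearest same-class neighbor, other cluster
            pick(row, masks["p_dminus"], farthest=True),  # most distant within-class neighbor
            pick(row, masks["n"], farthest=False),        # nearest between-class neighbor
        )
    return table

# Hypothetical 1-D embeddings: class 0 has two clusters, class 1 has a single sample.
emb = np.array([[0.0], [0.1], [0.3], [2.0], [2.2], [5.0]])
labels = np.array([0, 0, 0, 0, 0, 1])
clusters = np.array([0, 0, 0, 1, 1, 2])
table = build_quintuplet_table(emb, labels, clusters)
print(table[0])    # (0, 2, 3, 4, 5)
print(5 in table)  # False: the lone class-1 sample has no within-cluster neighbor
```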
Step 2: With the quintuplet table ready, sample the training batches uniformly across classes from the table, so that every class contributes equally.
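One way to implement the class-uniform sampling of step 2 is sketched below. It assumes the quintuplet table is a dict from anchor index to its 5-tuple of sample indices, which is a hypothetical layout. It draws the same number of anchors from every class, with replacement so that rare classes can still fill their quota.

```python
import numpy as np

def sample_balanced_batch(table, labels, per_class, rng):
    """Draw `per_class` quintuplets for every class that has anchors in the table.

    Sampling is with replacement, so a rare class can still fill its share of the batch.
    """
    anchors = np.array(sorted(table))
    anchor_class = labels[anchors]
    batch = []
    for c in np.unique(anchor_class):
        pool = anchors[anchor_class == c]
        for a in rng.choice(pool, size=per_class, replace=True):
            batch.append(table[int(a)])
    return batch

# Hypothetical table: 6 anchors of class 0, 2 of class 1 (partners elided as -1).
labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])
table = {i: (i, -1, -1, -1, -1) for i in range(8)}
batch = sample_balanced_batch(table, labels, per_class=4, rng=np.random.default_rng(0))
print([int(labels[q[0]]) for q in batch])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Even though class 1 has only a third as many anchors, it still supplies half of the batch.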
Step 3: Feed the sampled quintuplets into the ConvNet, whose parameters are shared across all five elements, to produce a feature embedding for each element of every quintuplet. Then compute the triple-header hinge loss from these embeddings.
Triple-header Hinge Loss
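The loss has one hinge term per inequality in the quintuplet rule, each with its own margin $g_1$, $g_2$, $g_3$. The training objective sums this over all anchors $x_i$:
$$
\begin{aligned}
\ell(x_i) = {} & \max\big(0,\ g_1 + D(f(x_i), f(x_i^{p+})) - D(f(x_i), f(x_i^{p-}))\big) \\
+ {} & \max\big(0,\ g_2 + D(f(x_i), f(x_i^{p-})) - D(f(x_i), f(x_i^{p--}))\big) \\
+ {} & \max\big(0,\ g_3 + D(f(x_i), f(x_i^{p--})) - D(f(x_i), f(x_i^{n}))\big)
\end{aligned}
$$
Let’s look at the first hinge term, since the same reasoning applies to the other two. The figure below helps with the intuition.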
In the figure, green is the margin $g_1$. Red is the distance from the anchor to its most distant within-cluster neighbor, $D(f(x_i), f(x_i^{p+}))$. Blue is the distance from the anchor to its nearest within-class neighbor in a different cluster, $D(f(x_i), f(x_i^{p-}))$.
Going back to the first term, we want $D(f(x_i), f(x_i^{p+})) - D(f(x_i), f(x_i^{p-}))$ to be negative even after adding the margin $g_1$. In other words, the model should learn embeddings in which the anchor is closer to every member of its own cluster than to the nearest same-class sample from any other cluster, by at least $g_1$. If that already holds, the term contributes zero penalty. If $g_1 + D(f(x_i), f(x_i^{p+})) - D(f(x_i), f(x_i^{p-}))$ is positive, that value is added to the loss.
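Here is a forward-pass sketch of the triple-header hinge loss in NumPy. It assumes squared Euclidean distance for $D$ and hypothetical margins of 1.0, and it averages over the batch. In a real training loop you would write the same arithmetic in an autodiff framework so that the next step can backpropagate through it.

```python
import numpy as np

def sq_dist(a, b):
    """Row-wise squared Euclidean distance, used here as D."""
    return ((a - b) ** 2).sum(axis=-1)

def triple_header_hinge_loss(f_a, f_pp, f_pm, f_pmm, f_n, margins=(1.0, 1.0, 1.0)):
    """Mean triple-header hinge loss over a batch of quintuplet embeddings.

    Each argument is a (batch, dim) array holding f(x_i), f(x_i^{p+}), f(x_i^{p-}),
    f(x_i^{p--}) and f(x_i^n) respectively.
    """
    g1, g2, g3 = margins
    d_pp, d_pm, d_pmm, d_n = (sq_dist(f_a, f) for f in (f_pp, f_pm, f_pmm, f_n))
    h1 = np.maximum(0.0, g1 + d_pp - d_pm)   # own cluster vs. other clusters of the same class
    h2 = np.maximum(0.0, g2 + d_pm - d_pmm)  # nearest vs. farthest within-class neighbor
    h3 = np.maximum(0.0, g3 + d_pmm - d_n)   # farthest within-class vs. nearest other class
    return float(np.mean(h1 + h2 + h3))

# One-dimensional toy quintuplets (hypothetical values).
anchor = np.array([[0.0]])
ok = triple_header_hinge_loss(anchor, np.array([[1.0]]), np.array([[3.0]]),
                              np.array([[4.0]]), np.array([[5.0]]))
bad = triple_header_hinge_loss(anchor, np.array([[1.0]]), np.array([[1.0]]),
                               np.array([[4.0]]), np.array([[5.0]]))
print(ok, bad)  # 0.0 1.0
```

In `ok`, every inequality holds with room to spare, so the loss is zero. In `bad`, the nearest other-cluster neighbor is as close as the farthest within-cluster one, so the first term contributes exactly the margin.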
Step 4: Backpropagate the loss to update the ConvNet parameters.
Step 5: Repeat steps 0 to 4, re-running the clustering and rebuilding the quintuplet table every 5000 iterations, until validation performance is satisfactory.
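The alternation between re-clustering and network updates can be sketched as a small driver loop. The three callbacks are hypothetical placeholders for the pieces described in the steps above. Checking validation only after each refresh interval is a simplifying choice of this sketch.

```python
def train(recluster, train_step, is_good_enough, max_iters, refresh_every=5000):
    """Alternate between refreshing clusters/quintuplets and parameter updates.

    recluster():       k-means on the current embeddings, then rebuild the quintuplet table
    train_step(it):    sample a class-balanced batch, compute the loss, backpropagate
    is_good_enough():  validation check that ends training
    Returns the number of update iterations performed.
    """
    it = 0
    while it < max_iters:
        recluster()
        for _ in range(refresh_every):
            if it >= max_iters:
                break
            train_step(it)
            it += 1
        if is_good_enough():
            break
    return it

# Stub callbacks that only record what happened (hypothetical).
log = []
n_iters = train(
    recluster=lambda: log.append("recluster"),
    train_step=lambda it: log.append("step"),
    is_good_enough=lambda: False,
    max_iters=12,
    refresh_every=5,
)
print(n_iters, log.count("recluster"), log.count("step"))  # 12 3 12
```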
Nearest Neighbor Imbalanced Classification
The standard kNN classifier implicitly assumes that all classes have roughly equal densities. Imbalanced data violates this assumption, so applying kNN directly is unfair to the minority classes.
The authors make two modifications to the standard kNN:
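- In the well-clustered feature space learned in the training stage, treat each cluster as a single class-specific exemplar, and perform a fast cluster-wise kNN search.
- Let $\Phi(q)$ be query q’s local neighborhood, defined by its k nearest cluster centroids $\{m_i\}_{i=1}^{k}$, and seek a large-margin local boundary within $\Phi(q)$ instead of taking a plain majority vote.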
Cluster-wise kNN search
For query q,
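- Find q’s k nearest cluster centroids, $\Phi(q)$, among the clusters of all classes learned in the training stage.
- If all k neighboring clusters belong to the same class, label q with that class and stop.
- Otherwise, label q as $y_q$ given by the large-margin rule
$$y_q = \arg\max_{c = 1, \dots, C} \Big( \min_{m_j \in \Phi(q),\ y_j \neq c} D(f(q), m_j) - \max_{m_i \in \Phi(q),\ y_i = c} D(f(q), m_i) \Big)$$
where $y_i$ is the class of centroid $m_i$ and $C$ is the number of classes.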
Step 3, explained: label q as the class whose maximum cluster distance is smaller than its minimum cluster distance to any other class by the largest margin.
In other words, we choose the class c that:
- maximizes the minimum distance from q to the clusters of any class other than c,
- while minimizing the maximum distance from q to the clusters of c itself.
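Here is a sketch of the cluster-wise kNN search with the large-margin rule. It assumes the cluster centroids live in the learned embedding space and that $D$ is squared Euclidean distance. It also assumes the argmax only ranges over classes that actually appear among q's k nearest centroids, because for an absent class the max over its own clusters is undefined. Names and toy centroids are hypothetical.

```python
import numpy as np

def classify_query(f_q, centroids, centroid_class, k):
    """Cluster-wise kNN with the large-margin decision rule.

    f_q: (dim,) query embedding; centroids: (m, dim); centroid_class: (m,) class of each centroid.
    """
    d = ((centroids - f_q) ** 2).sum(axis=1)  # D(f(q), m_i) for every centroid
    nn = np.argsort(d)[:k]                    # the k nearest centroids, i.e. Phi(q)
    nn_d, nn_c = d[nn], centroid_class[nn]
    if np.all(nn_c == nn_c[0]):               # unanimous neighborhood: take that class
        return int(nn_c[0])
    best_class, best_margin = None, -np.inf
    for c in np.unique(nn_c):                 # assumption: only classes present in Phi(q)
        margin = nn_d[nn_c != c].min() - nn_d[nn_c == c].max()
        if margin > best_margin:
            best_class, best_margin = int(c), margin
    return best_class

# Hypothetical 1-D centroids: two clusters for class 0, two for class 1.
centroids = np.array([[0.0], [1.0], [1.5], [10.0]])
centroid_class = np.array([0, 0, 1, 1])
print(classify_query(np.array([1.2]), centroids, centroid_class, k=3))  # 1
```

With k = 3, the neighborhood holds two class-0 centroids and one class-1 centroid, so a plain majority vote would return class 0. The large-margin rule returns class 1 instead. The class-1 centroid is almost as close to q as the nearest class-0 centroid, whereas the second class-0 centroid is much farther away than the class-1 one.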
Experiment Results on CelebA dataset
As you move to the right, the class imbalance of the face attribute increases. The colors represent the relative accuracy gain over each of the competing models shown in the legend.
The authors also report two more experiments, on an edge detection task and on MNIST-rot-back-image, which are worth a look.
The authors take into account that imbalance exists not only between classes but also within a single class, and they build this into the quintuplets and the hinge loss. I wonder whether further improvement could come from applying this loss not just to the final-layer embedding, but to a weighted combination of several layers, each of which holds a different feature representation.
References
C. Drummond and R. C. Holte. C4.5, class imbalance, and cost sensitivity: Why under-sampling beats over-sampling. In ICMLW, 2003.
H. He and E. A. Garcia. Learning from imbalanced data. TKDE, 21(9):1263–1284, 2009.
P. Jeatrakul, K. Wong, and C. Fung. Classification of imbalanced data by combining the complementary neural network and SMOTE algorithm. In ICONIP, 2010.
C. Huang, Y. Li, C. C. Loy, and X. Tang. Learning deep representation for imbalanced classification. In CVPR, 2016.