from fastai.vision.all import *
Introduction
In a previous post we saw our Bear classifier confidently predict that an image of a Maine Coon was a Teddy bear. Roughly speaking the problem, in a multi-class setting, is that the exponential (in the softmax) pushes the class with the highest activation to receive a score close to \(1\). So even though our classifier had never seen a Maine Coon during training, at inference time the image of the Maine Coon happened to push the activations for the Teddy Bear class higher than those for the other classes (Black Bear and Grizzly).
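To see this concretely, here is a quick illustration with made-up activations: softmax normalizes across classes, so even uniformly low activations produce one score near \(1\), whereas the sigmoid scores each class independently and can leave every score low.

acts = torch.tensor([-1.0, -5.0, -4.0])  # all weak in absolute terms
acts.softmax(dim=0)  # the largest activation still soaks up ~0.94 of the mass
acts.sigmoid()       # scored independently, every class stays below 0.27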
In this post we use multi-label classification to solve this problem, following Chapter 6 of [1] and the note in [2]. A multi-label classification problem (as opposed to multi-class classification) is one where every object can have multiple labels, instead of a single correct class.
We want our multi-label classifier to predict the different labels present in the object, or no label if it is not confident about the presence of any of the labels. For this we switch to using Binary Cross Entropy as our loss function (and will also come up with a threshold probability for our classifier to use at inference).
Binary Cross Entropy
Suppose the following are the activations, on a single training example, for the different labels possible in a multi-label classification problem. Pretend that we can have as many as three labels for each image.
activations = torch.randn(1,3)*3
activations
tensor([[ 0.7260, -3.3489, -2.3491]])
Next suppose the correct labels in this case are the first and third label. A one-hot encoded representation of the labels for this single training example is as follows.
targets = tensor([[1,0,1]])
targets
tensor([[1, 0, 1]])
The first step is to apply the sigmoid to the activation for each label to convert it into a probability score.
sigmoid_activations = activations.sigmoid()
sigmoid_activations
tensor([[0.6739, 0.0339, 0.0871]])
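As a sanity check (not in the original notebook), this matches computing the sigmoid by hand from its formula \(\sigma(x) = 1/(1+e^{-x})\):

1/(1 + torch.exp(-activations))  # same tensor([[0.6739, 0.0339, 0.0871]])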
The next step is to compute the binary cross entropy loss using
bce_loss = -torch.where(targets==1, sigmoid_activations, 1-sigmoid_activations).log()
bce_loss
tensor([[0.3946, 0.0345, 2.4403]])
When a label is present in the target, we take the sigmoid activation corresponding to that label. Otherwise, when the label is absent, we take one minus that sigmoid activation. By doing this we are asking for the confidence that the classifier places on the particular label being absent.
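Equivalently, as a quick check, the torch.where expression matches the textbook binary cross entropy formula \(-[y\log p + (1-y)\log(1-p)]\):

# same numbers as the torch.where version above
-(targets*sigmoid_activations.log() + (1-targets)*(1-sigmoid_activations).log())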
Thus the binary cross entropy function summarizes the “correctness” of the classifier across the presence and absence of each possible label in a training example. Contrast this with the use of cross entropy loss in multi-class classification, where we only care about the probability being predicted for the single correct class of the training example.
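To make the contrast concrete, here is a sketch of what multi-class cross entropy would look at, pretending (for illustration only) that class 0 were the single correct class:

import torch.nn.functional as F
# cross entropy only penalizes the probability assigned to the correct class
F.cross_entropy(activations, torch.tensor([0]))
# equivalent to -log of the softmax probability of class 0
-activations.softmax(dim=1)[0, 0].log()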
results = torch.reshape(torch.cat([activations, sigmoid_activations, targets, bce_loss], 0),(4,3))
pd.DataFrame(results, index=['activations', 'sigmoid','target','loss'])
|             | 0        | 1         | 2         |
|---|---|---|---|
| activations | 0.725968 | -3.348915 | -2.349130 |
| sigmoid     | 0.673920 | 0.033931  | 0.087135  |
| target      | 1.000000 | 0.000000  | 1.000000  |
| loss        | 0.394644 | 0.034520  | 2.440297  |
Each row in the dataframe above shows one of the steps needed to go from an activation to the loss associated with each possible label of a single training example. We can get the losses directly by
F.binary_cross_entropy_with_logits(activations, targets.float(), reduction="none")
tensor([[0.3946, 0.0345, 2.4403]])
Or by
nn.BCEWithLogitsLoss(reduction='none')(activations, targets.float())
tensor([[0.3946, 0.0345, 2.4403]])
Of course we will want to take the mean of the loss across the labels, using either of the two methods below
F.binary_cross_entropy_with_logits(activations, targets.float()), nn.BCEWithLogitsLoss()(activations, targets.float())
(tensor(0.9565), tensor(0.9565))
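As a quick check, this is just the mean of the per-label losses we computed earlier:

bce_loss.mean()  # tensor(0.9565)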
Soft Labels
In the preceding section the labels were hard. In other words, we knew with certainty which labels were present for any training example. What if we only knew the probability with which a label would be present in a training example?
Feel free to skip to the next section as this is somewhat of a digression.
In the example below we have four training examples and only a single label. The activations for the training examples are as follows:
activations = tensor([[0.0122],[0.2294],[-0.1563],[-0.1483]])
activations.shape
torch.Size([4, 1])
Apply sigmoid to get the predictions for whether the label is present.
sigmoid_activations = activations.sigmoid()
sigmoid_activations
tensor([[0.5030],
[0.5571],
[0.4610],
[0.4630]])
Get the predictions for whether the label is absent (i.e., some other label is present).
one_minus_sigmoid_activations = (1. - sigmoid_activations)
one_minus_sigmoid_activations
tensor([[0.4970],
[0.4429],
[0.5390],
[0.5370]])
The targets for the examples are as follows. This says that in the first training example label 0 is present with probability 0.5; similarly, the probability of label 0 being present in the second example is 0.29, and so on.
targets = tensor([0.5, 0.29, 0.36, 0.03])
targets.shape
torch.Size([4])
Compute the binary cross entropy loss for each training example.
bce_loss = -1 * ( targets*sigmoid_activations.flatten().log() + (1. - targets)*( one_minus_sigmoid_activations.flatten().log() ) )
bce_loss
tensor([0.6932, 0.7479, 0.6743, 0.6262])
Compute the average binary cross entropy loss across training examples.
bce_loss.mean()
tensor(0.6854)
The above can just be done as a one-liner using BCEWithLogitsLossFlat
BCEWithLogitsLossFlat(reduction='mean')(activations, targets)
TensorBase(0.6854)
Note that we cannot just use nn.BCEWithLogitsLoss because the input tensor isn’t flattened and we will get an error.
ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 1]))
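One workaround, sketched here for completeness, is to flatten the activations ourselves so that the shapes match:

# (4,1) activations flattened to (4,) to match the targets
nn.BCEWithLogitsLoss(reduction='mean')(activations.flatten(), targets)  # tensor(0.6854)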
Now suppose we have two examples and three soft labels per example.
activations = tensor([[0.0122, 0.001, 0.5],[0.2294, 0.3, 0.75]])
activations.shape
torch.Size([2, 3])
The soft labels are:
targets = tensor([[0.9, 0.1, 0.2],[0.25, 0.5, 0.33]])
targets.shape
torch.Size([2, 3])
We can compute the binary cross entropy loss in the cumbersome way as below:
bce_loss = -1 * ( targets * activations.sigmoid().log() + (1. - targets)*( 1.0 - activations.sigmoid() ).log() )
bce_loss, bce_loss.sum(), bce_loss.mean()
(tensor([[0.6883, 0.6935, 0.8741],
[0.7571, 0.7044, 0.8894]]), tensor(4.6067), tensor(0.7678))
Or just use the one-liner to get the average loss.
nn.BCEWithLogitsLoss(reduction='mean')(activations, targets)
tensor(0.7678)
I find it useful to think of each individual training example as being replicated as many times as there are possible labels. Then 2 training examples each with 3 possible labels is better thought of as 6 training examples. Hence the average loss, in this setting, requires us to divide by \(2*3 = 6\).
Note that we could reasonably think that the loss for each label would be summed up into a single loss per training example, which we would then divide by \(2\), but do me a favor and don't walk down that path!
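A quick check makes the bookkeeping concrete: dividing the total loss by all \(2*3 = 6\) (example, label) pairs reproduces the mean, while the per-example route gives something else entirely.

bce_loss.sum() / 6          # tensor(0.7678), matches bce_loss.mean()
bce_loss.sum(dim=1).mean()  # tensor(2.3034), the path not to walk down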
Bear Images
Now let’s get back to our multi-label classification exercise. So we start by assembling images of bears.
from jmd_imagescraper.core import *
def scrape_images(path, labels, search_suffix, erase_dir=True, max_images=20):
    if erase_dir:
        !rm -rf {path}
    if not path.exists():
        path.mkdir(parents=True)
    for some_label in labels:
        duckduckgo_search(path, some_label, f'{some_label} {search_suffix}', max_results=max_images)
    filenames = get_image_files(path)
    failed = verify_images(filenames)
    failed.map(Path.unlink)
    if failed != []:
        _ = [filenames.remove(f) for f in failed]
    # To avoid Transparency warnings, convert PNG images to RGBA
    # https://forums.fast.ai/t/errors-when-training-the-bear-image-classification-model/83422/9
    converted = L()
    for image in filenames:
        if '.png' in str(image):
            im = Image.open(image)
            converted.append(image)  # old file name before resaving
            im.convert("RGBA").save(f"{image}2.png")
    converted.map(Path.unlink)  # delete originals
    total_images = len(get_image_files(path))
    print(f"After checking for issues, {total_images} (total) images remain.")
    return path

labels = 'grizzly','black','teddy'
path = scrape_images(Path('/data/kaushik/bears'), labels, 'bear', max_images=100)
Duckduckgo search: grizzly bear
Downloading results into /data/kaushik/bears/grizzly
Duckduckgo search: black bear
Downloading results into /data/kaushik/bears/black
Duckduckgo search: teddy bear
Downloading results into /data/kaushik/bears/teddy
After checking for issues, 300 (total) images remain.
Set up Data Block
We will need a multicategory block and a function that takes a single label and converts it into a list.
def get_label_list(some_label): return [some_label]
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=0),
    get_y=Pipeline([parent_label, get_label_list]),
    item_tfms=RandomResizedCrop(228, min_scale=0.5),
    batch_tfms=aug_transforms())

dls = dblock.dataloaders(path)
dls.show_batch(max_n=8, nrows=2)
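To see what the get_y pipeline produces, we can run it on a single path by hand (the file name below is hypothetical) and peek at the vocabulary the MultiCategoryBlock one-hot encodes against:

# hypothetical file path; parent_label reads the parent directory name
Pipeline([parent_label, get_label_list])(Path('/data/kaushik/bears/teddy/0001.jpg'))  # ['teddy']
dls.vocab  # ['black', 'grizzly', 'teddy']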
Train Model
For the model we will use fastai's BCEWithLogitsLossFlat to compute the binary cross entropy loss. Another item to note is the use of accuracy_multi as the metric. accuracy_multi is defined as
def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    "Compute accuracy when `inp` and `targ` are the same size."
    inp,targ = flatten_check(inp,targ)
    if sigmoid: inp = inp.sigmoid()
    return ((inp>thresh)==targ.bool()).float().mean()
This shows that accuracy_multi uses the threshold to determine whether a label is present, and then takes a mean across the labels for every example in the inputs. The training process does not care about the threshold; it only cares about minimizing the loss function.
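A toy example with made-up numbers shows how accuracy_multi scores a prediction at the default thresh=0.5:

inp = tensor([[2.0, -2.0, 0.1]])   # sigmoids roughly [0.88, 0.12, 0.52]
targ = tensor([[1, 0, 0]])
accuracy_multi(inp, targ)          # 2 of 3 labels correct -> tensor(0.6667)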
learn = cnn_learner(dls, resnet18, metrics=[partial(accuracy_multi, thresh=0.95),
                                            APScoreMulti()],
                    loss_func=BCEWithLogitsLossFlat())
learn.fine_tune(4)
| epoch | train_loss | valid_loss | accuracy_multi | average_precision_score | time |
|---|---|---|---|---|---|
| 0 | 1.037557 | 0.579388 | 0.783333 | 0.868948 | 00:07 |

| epoch | train_loss | valid_loss | accuracy_multi | average_precision_score | time |
|---|---|---|---|---|---|
| 0 | 0.497570 | 0.354649 | 0.883333 | 0.960842 | 00:09 |
| 1 | 0.386510 | 0.170127 | 0.961111 | 0.987437 | 00:09 |
| 2 | 0.306835 | 0.135320 | 0.966667 | 0.989491 | 00:09 |
| 3 | 0.263357 | 0.124831 | 0.966667 | 0.989491 | 00:09 |
Let’s look at the results.
learn.show_results()
Sadness revisited
Let’s test with our Maine Coon image.
test_mainecoon_image_location = '../test_images/mc.jpg'
im = Image.open(test_mainecoon_image_location)
im.to_thumb(128,128)
learn.predict(test_mainecoon_image_location)
((#2) ['black','teddy'],
tensor([ True, False, True]),
tensor([0.5806, 0.0633, 0.8982]))
What happened? Our multi-label classifier still chooses to predict some of the labels.
Looking at the predicted probabilities we see that none of them is greater than 0.95, so the threshold we passed in for accuracy_multi is not being used during inference.
If we look at the source of learn.predict we see that it calls out to learn.get_preds(…, with_decoded=True). Per the documentation of get_preds, with_decoded=True will also return the decoded predictions using the decodes function of the loss function (if it exists).
Looking at the definition of BCEWithLogitsLossFlat we see the presence of the decodes method.
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    @use_kwargs_dict(keep=True, weight=None, reduction='mean', pos_weight=None)
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        if kwargs.get('pos_weight', None) is not None and kwargs.get('flatten', None) is True:
            raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
        if kwargs.get('pos_weight', None) is not None: kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh

    def decodes(self, x):    return x>self.thresh
    def activation(self, x): return torch.sigmoid(x)
Thus, all we need to do is pass the threshold into BCEWithLogitsLossFlat.
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.95)
learn.predict(test_mainecoon_image_location)
((#0) [], tensor([False, False, False]), tensor([0.5806, 0.0633, 0.8982]))
Much better. Now none of the labels are predicted.
To get the best threshold we first get all the predictions on the validation set.
preds,targs = learn.get_preds()
Since the predictions from get_preds already have the sigmoid applied, we will pass sigmoid=False when assessing accuracy_multi at different threshold values.
xs = torch.linspace(0.05,0.995,100)
accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]
plt.plot(xs,accs);
Looks like our choice of threshold is fine. One more thing to note is that for multi-label problems it is useful to look at the average precision score using APScoreMulti. This is because on harder datasets it will provide a decidedly less rosy view of classifier performance than accuracy_multi.
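Under the hood APScoreMulti wraps scikit-learn's average_precision_score, so (if I understand the wrapper correctly) the reported score can be reproduced directly from the validation predictions:

# macro-averaged AP over the labels; preds/targs from learn.get_preds() above
from sklearn.metrics import average_precision_score
average_precision_score(targs, preds, average='macro')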
As a future step it would be nice to do threshold selection for each label following [3].