Let me answer step by step.
1) How metrics are calculated and why you may get NaN
Metrics are calculated as:
- accuracy = (TP + TN) / N
- precision = TP / (TP + FP)
- recall = TP / (TP + FN)
- specificity = TN / (TN + FP)
During training, these metrics are computed on each batch.
Therefore, you may have batches (in the example, batch_size = 16) where the denominator of one of those fractions is zero, which produces NaN. For example, precision is NaN whenever the batch contains no TP and no FP. The only metric that can never be NaN is accuracy, because its denominator is the number of samples in the batch, regardless of their class.
In our example, we chose to report validation metrics as the average over batches, ignoring NaN values. You may modify this approach as you prefer.
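Here is a minimal sketch of that logic in Python/NumPy (the function name batch_metrics and the exact array handling are mine for illustration, not necessarily what your notebook uses):

```python
import numpy as np

def batch_metrics(preds, labels):
    """Confusion-matrix counts and metrics for one batch of 0/1 predictions."""
    preds, labels = np.asarray(preds), np.asarray(labels)
    tp = np.sum((preds == 1) & (labels == 1))
    tn = np.sum((preds == 0) & (labels == 0))
    fp = np.sum((preds == 1) & (labels == 0))
    fn = np.sum((preds == 0) & (labels == 1))
    n = len(labels)
    # Each of these ratios can be 0 / 0 for a given batch, which yields NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        precision   = np.float64(tp) / (tp + fp)
        recall      = np.float64(tp) / (tp + fn)
        specificity = np.float64(tn) / (tn + fp)
    accuracy = (tp + tn) / n  # denominator is always the batch size, never NaN
    return float(accuracy), float(precision), float(recall), float(specificity)

# Epoch-level values: average the per-batch metrics, skipping NaN batches,
# e.g. epoch_precision = np.nanmean(per_batch_precisions)
```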
2) How you may have obtained Precision = NaN, Recall = 0 and Specificity = 1
Let us consider a batch like this one:
preds = [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
labels = [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
In this case, the metrics take exactly the values you shared, as reproduced in the sketch after this list:
- tp = 0
- tn = 15
- fp = 0
- fn = 1
- b_precision = tp / (tp + fp) = nan
- b_recall = tp / (tp + fn) = 0.0
- b_specificity = tn / (tn + fp) = 1.0
- b_accuracy = (tp + tn) / n = 0.9375
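Feeding this exact batch to the batch_metrics sketch above reproduces those numbers:

```python
preds  = [0] * 16
labels = [0, 0, 0, 1] + [0] * 12
# accuracy, precision, recall, specificity
print(batch_metrics(preds, labels))  # (0.9375, nan, 0.0, 1.0)
```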
One epoch evaluates multiple batches.
If this is the output of a whole epoch (multiple batches), it implies that true positives are rare, or absent, across the batches of that epoch. The combination of precision = NaN and recall = 0 is the hint that there were no true positives.
3) "It happens randomly with different datasets"
If you are reusing this code on a problem with different input data, you may need to adapt the code or the hyperparameters to the new problem, and you should also verify your input data.
For example:
- Are you trying to solve a binary classification problem (2 labels: 0 vs. 1) or is it a multi-class classification problem (more than 2 labels)?
- Is your dataset balanced? (i.e., do you have a similar number of observations for each label?)
- How are the labels distributed across the train and validation sets? (see the sketch after this list)
- Did you try different hyperparameter values among the ones suggested in the BERT paper? For example, increasing the batch size from 16 to 32.
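As a quick check of the balance and label-distribution points, something like this prints the class proportions of each split (train_labels and val_labels are placeholders for whatever label lists you actually feed to the model):

```python
from collections import Counter

def label_distribution(labels, name):
    """Print the per-class counts and proportions for a list of labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    for cls in sorted(counts):
        print(f"{name}: class {cls}: {counts[cls]} ({counts[cls] / total:.1%})")

# Hypothetical label lists; replace with your own train/validation labels.
train_labels = [0, 0, 0, 1, 0, 0, 1, 0]
val_labels   = [0, 0, 0, 0, 0, 1, 0, 0]
label_distribution(train_labels, "train")
label_distribution(val_labels, "validation")
```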
4) "After that, the model outputs same predictions on different outputs"
It is only a guess (I have not seen your data/code), but it may well be that your model almost always saw batches like this one during training:
preds = [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
labels = [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
i.e. with very few/no true positives, and therefore only "learned" to predict class 0 (for any input). I would suggest checking the points above.
Let me know if this helps.
Best,