Once you have a trained model, it is tempting to assume it works well in practice simply because its evaluation metric (e.g., accuracy, sparse categorical accuracy) is high. In my case, training accuracy was consistently above 98%. On one hand, this is expected, as the task is relatively simple (at a minimum, the model should learn to pay attention to salient politics-related keywords in the training set and generalize this to other common related keywords). On the other hand, it is important to verify that the model is not exploiting spurious shortcuts.
One way to do this is to explore gradient-based attributions as a way to explain model behaviour. By looking at the gradients of a deep learning model's outputs with respect to its inputs, we can infer the importance/contribution of each token to the model's output. This approach is sometimes referred to as vanilla gradients or gradient sensitivity.
In TensorFlow/Keras, obtaining gradients (via automatic differentiation) can be implemented using the TensorFlow GradientTape API. The general steps are outlined below:
- Initialize a GradientTape, which records operations for automatic differentiation.
- Create a one-hot vector representing each input token (note that inputs correspond to the token indices for words as generated by the tokenizer). We will instruct TensorFlow to watch this variable within the gradient tape.
- Multiply the one-hot input by the embedding matrix; this way we can backpropagate the prediction with respect to the input.
- Get predictions for the input tokens.
- Get the gradient of the predicted class with respect to the input. For a classification model with n classes, we zero out the other n-1 classes, keeping only the predicted class (i.e., the class with the highest prediction logit), and take the gradient with respect to it.
- Normalize the gradients to the range [0, 1] and return them as explanations.
Inspecting these gradients yielded a few insights and led to changes in my experiment setup:
- Adding custom tokens: Visualizing tokens and their contributions helped me realize the need to expand the token vocabulary used to train the model. My dataset had a lot of Nigerian names that the standard BERT tokenizer does not contain, causing them to be represented by partial (subword) tokens. I used the `tokenizer.add_tokens` method and resized my model's embedding layer to fix this.
- Informing strategies for improving the model: Looking at how importance is assigned to each token was useful in rewriting the initial heuristics used in generating the training set (e.g., not matching on some frequently occurring last names to minimize spurious attributions), introducing additional keywords, updating my preprocessing logic, etc.
Note that there are other methods and variations of gradient-based attribution; however, vanilla gradients are particularly straightforward to implement and yield fairly similar results.