Hi all, hope you are well!
In this issue I'll talk about some of my recent experiments building a simple text classification model (predicting whether a tweet is related to politics or not) and debugging/inspecting the model's behaviour using gradients.
Context - Understanding Election Discourse in Developing Countries
I have always been passionate about applying machine learning techniques to understanding social and demographic issues, especially in understudied regions (e.g., developing countries). A while ago I assembled a dataset of 25 million tweets spanning the 6 months leading up to the February 2019 Nigerian general elections. The goal was to analyze this data (see some initial results here) and answer questions around citizen participation, network structures, demographic distribution, and even wealth disparities in the country.
While the topics here are quite broad, one specific area where machine learning can help is in understanding election-specific discourse. The first step is to reliably extract the subset of all tweets that pertain to politics and then ask further questions within this subset (e.g., sentiments associated with tweets, changes in sentiment over time, sentiments associated with events, hashtags associated with political discourse, sentiment around candidates, etc.).
To address this task, I chose to build a BERT-based model using the Hugging Face Transformers library and TensorFlow 2.0. A competent politics classifier would be useful for annotating all tweets in the dataset, enabling new types of analysis.
Dataset Building
All 25 million tweets collected are currently stored in a BigQuery table. This allows the generation of derivative tables that are used to assemble a training dataset. Traditionally, a machine learning dataset is built by human annotators who assign a label to each tweet. However, you can go quite far by writing heuristics/rules that assign an initial set of labels, which are then used to train a baseline model. I explored this route:
Assemble a list of politics-related keywords (e.g., candidate names, and terms such as election, president, presidential, senate, house of representatives, etc.).
Run queries that match tweets against regexes built from the keywords above. Matches are given a "political" label, and an equally sized sample of tweets is drawn from the rest as the "not political" class. Care has to be taken in selecting keywords and regexes; for example, a keyword list that is too general results in a model that thinks everything is about politics. A minimal labeling sketch follows this list.
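To make this concrete, here is a minimal sketch of the weak-labeling step, assuming the tweets have been exported from BigQuery into a pandas DataFrame with a `text` column (the keyword list, column name, and `build_training_set` helper are illustrative, not my exact setup):

```python
import pandas as pd

# Illustrative keywords only; my actual list also included candidate names.
POLITICS_KEYWORDS = [
    "election", "president", "presidential", "senate", "house of representatives",
]
# Word-boundary regex so matches are on whole words/phrases.
POLITICS_PATTERN = r"\b(" + "|".join(POLITICS_KEYWORDS) + r")\b"


def build_training_set(tweets: pd.DataFrame, seed: int = 42) -> pd.DataFrame:
    """Heuristically label tweets: 1 = political, 0 = not political."""
    is_political = tweets["text"].str.contains(POLITICS_PATTERN, case=False, regex=True, na=False)
    political = tweets[is_political].assign(label=1)
    # Draw an equally sized sample from the remaining tweets as the "not political" class.
    not_political = tweets[~is_political].sample(n=len(political), random_state=seed).assign(label=0)
    # Shuffle the combined dataset before training.
    return pd.concat([political, not_political]).sample(frac=1, random_state=seed)
```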
Model Building with Huggingface Transformers
The Hugging Face Transformers library makes the process of building models really straightforward. The bulk of the work usually goes towards writing an efficient data pipeline to train the model and get predictions.
Some of the high-level steps I took include:
Training:
Shuffle and chunk large datasets into smaller splits.
Tokenize the text for each split and construct a `tf.data` object.
Iteratively train the model on each split (reserving a split for evaluation as needed). A minimal training sketch follows this list.
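Here is a rough sketch of those training steps, assuming `bert-base-uncased`, a hypothetical `train_splits` list of `(texts, labels)` chunks, and illustrative hyperparameters (this mirrors the standard Keras fine-tuning pattern from the Transformers docs, not my exact code):

```python
import tensorflow as tf
from transformers import BertTokenizerFast, TFBertForSequenceClassification

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
)


def make_dataset(texts, labels, batch_size=32):
    # Tokenize one split and wrap it in a tf.data pipeline.
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=64, return_tensors="tf")
    dataset = tf.data.Dataset.from_tensor_slices((dict(encodings), labels))
    return dataset.shuffle(len(labels)).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


# Iteratively train on each split (hold one split out for evaluation as needed).
for texts, labels in train_splits:  # train_splits: hypothetical list of (texts, labels) chunks
    model.fit(make_dataset(texts, labels), epochs=1)
```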
Prediction:
Chunk data into splits
Apply batching optimizations to speed up prediction on each split. For example, construct batches from text sorted by length, so that similarly sized texts end up in the same batch (a prediction sketch follows this list).
Predict on batches and aggregate results.
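And a corresponding sketch of length-sorted prediction, reusing the `tokenizer` and `model` from the training snippet above; the whitespace-token count as a length proxy, the batch size, and the `.logits` output attribute (available in recent Transformers versions) are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf


def predict_sorted(texts, batch_size=64):
    """Predict in length-sorted batches, then restore the original tweet order."""
    order = np.argsort([len(t.split()) for t in texts])  # rough proxy for token length
    sorted_texts = [texts[i] for i in order]

    batch_probs = []
    for start in range(0, len(sorted_texts), batch_size):
        batch = sorted_texts[start:start + batch_size]
        encodings = tokenizer(batch, truncation=True, padding=True, max_length=64, return_tensors="tf")
        logits = model(dict(encodings), training=False).logits
        batch_probs.append(tf.nn.softmax(logits, axis=-1).numpy())

    probs = np.concatenate(batch_probs)
    # Undo the sort so predictions line up with the original input order.
    unsorted = np.empty_like(probs)
    unsorted[order] = probs
    return unsorted
```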
Additional details and code snippets for implementing a text classification model in TensorFlow are here: https://victordibia.com/blog/text-classification-hf-tf2/
Model Debugging - Gradient Explanations
Once you have a trained model, it is tempting to assume it works well in practice because its evaluation metric is high (e.g., accuracy, sparse categorical accuracy, etc.). In my case, training accuracy was consistently above 98%. On one hand, this is expected, as the task is relatively simple (at a minimum, the model should learn to pay attention to salient politics-related keywords in the training set and generalize this to other common related keywords). On the other hand, it is important to verify that the model is not exploiting spurious shortcuts.
One way to do this is to explore gradient-based attributions as a way to explain model behaviour. By looking at the gradients of a deep learning model's outputs with respect to its inputs, we can infer the importance/contribution of each token to the model's output. This approach is sometimes referred to as vanilla gradients or gradient sensitivity.
In TensorFlow/Keras, obtaining gradients (via automatic differentiation) can be implemented using the TensorFlow GradientTape API. The general steps are outlined below, followed by a minimal sketch:
Initialize a GradientTape which records operations for automatic differentiation.
Create a one-hot vector that represents our input tokens (note that each input corresponds to the token index generated by the tokenizer). We will instruct TensorFlow to `watch` this variable within the gradient tape.
Multiply the input by the embedding matrix; this way we can backpropagate the prediction with respect to the input.
Get prediction for input tokens
Get the gradient of the input with respect to the predicted class. For a classification model with `n` classes, we zero out all the other `n-1` classes except for the predicted class (i.e., the class with the highest prediction logits) and get the gradient with respect to this predicted class.
Normalize the gradients (0-1) and return them as explanations.
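A minimal sketch of these steps is below, again reusing the `tokenizer` and `model` from earlier. Note that the attribute holding the word-embedding matrix (`model.bert.embeddings.word_embeddings` here) and the `.logits` output attribute vary across Transformers versions, so treat those lines as assumptions to adapt to your setup:

```python
import tensorflow as tf


def explain_prediction(text):
    """Vanilla-gradient (gradient sensitivity) explanation for a single piece of text."""
    encodings = tokenizer(text, return_tensors="tf")
    token_ids = encodings["input_ids"]  # shape (1, seq_len)

    # One-hot encode token ids so we can take gradients with respect to the input.
    one_hot = tf.one_hot(token_ids, depth=model.config.vocab_size)  # (1, seq_len, vocab)

    with tf.GradientTape() as tape:
        tape.watch(one_hot)
        # Multiply by the embedding matrix so the prediction is differentiable w.r.t. the one-hot input.
        embedding_matrix = model.bert.embeddings.word_embeddings  # assumption: attribute varies by version
        inputs_embeds = tf.einsum("bsv,vh->bsh", one_hot, embedding_matrix)
        logits = model(inputs_embeds=inputs_embeds,
                       attention_mask=encodings["attention_mask"]).logits  # (1, num_classes)
        # Keep only the predicted-class logit (equivalent to zeroing out the other n-1 classes).
        target = tf.reduce_max(logits, axis=-1)

    grads = tape.gradient(target, one_hot)             # (1, seq_len, vocab)
    scores = tf.reduce_sum(tf.abs(grads), axis=-1)[0]  # importance per token
    # Normalize scores to the 0-1 range.
    scores = (scores - tf.reduce_min(scores)) / (tf.reduce_max(scores) - tf.reduce_min(scores) + 1e-9)

    tokens = tokenizer.convert_ids_to_tokens(token_ids[0].numpy().tolist())
    return list(zip(tokens, scores.numpy().tolist()))
```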
Additional details and code snippets for implementing gradient explanations are here: https://victordibia.com/blog/explain-bert-classification/.
Inspecting gradients yielded a few insights and led to some changes in my experiment setup:
Adding custom tokens: Visualizing tokens and their contributions helped me realize the need to expand the token vocabulary used to train the model. My dataset had a lot of Nigerian names that the standard BERT tokenizer did not contain, so they were represented by partial (subword) tokens. I used the `tokenizer.add_tokens` method and resized my model's embedding matrix to fix this (a short sketch follows this list).
Informing strategies for improving the model: Looking at how importance is assigned to each token was useful in rewriting the initial heuristics used to generate the training set (e.g., not matching on some frequently occurring last names, to minimize spurious attributions), introducing additional keywords, updating my preprocessing logic, etc.
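For the custom tokens fix, the general pattern looks like the sketch below (the names listed are hypothetical examples, not my actual vocabulary additions):

```python
# Hypothetical examples of names missing from the standard BERT vocabulary.
new_tokens = ["atiku", "buhari", "okonkwo"]
num_added = tokenizer.add_tokens(new_tokens)

# Resize the model's token embedding matrix to make room for the new entries.
model.resize_token_embeddings(len(tokenizer))
print(f"Added {num_added} tokens; vocab size is now {len(tokenizer)}")
```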
Note that there are other methods/variations of gradient-based attribution; however, vanilla gradients are particularly straightforward to implement and yield fairly similar results.
Conclusions
Thanks for reading! Hopefully the steps above are useful for any text or tweet classification projects you are working on. Reach out on Twitter to discuss any related topics!
Be well!