Spam Detection: Train in one language, Predict in another language? | Zero-Shot Learning | PyTorch
Zero-Shot Learning has been around in the AI/ML field for a long time. It has been used extensively on images, and with newer SOTA NLP models such as Sentence Transformers it is now being applied to NLP tasks more robustly. This blog doesn't focus on the details of Zero-Shot Learning but touches on some high-level basics. The blog linked below explains the underlying logic in depth and very clearly, and I recommend reading it. In simple terms, Zero-Shot Learning usually means:
We can train our classifier on one set of labels and then use the model to predict on another set of labels that it hasn't seen before.
In image classification tasks, one common approach is to use an existing featurizer to embed an image and any possible class names into their corresponding latent representations (e.g. Socher et al. 2013). They can then take some training set and use only a subset of the available labels to learn a linear projection to align the image and label embeddings. At test time, this framework allows one to embed any label (seen or unseen) and any image into the same latent space and measure their distance.
In the text domain, we have the advantage that we can trivially use a single model to embed both the data and the class names into the same space, eliminating the need for the data-hungry alignment step.
— https://joeddav.github.io/blog/2020/05/29/ZSL.html
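The idea of embedding texts and class names into the same space and measuring their distance can be illustrated with a minimal sketch. It assumes cosine similarity as the distance measure, and the three-dimensional vectors and the `nearest_label` helper are made up for illustration; real sentence embeddings would have hundreds of dimensions.

```python
import torch
import torch.nn.functional as F


def nearest_label(text_vecs, label_vecs):
    """Return, for each text vector, the index of the most similar label vector."""
    # Normalising to unit length turns the dot product into cosine similarity.
    text_unit = F.normalize(text_vecs, dim=1)
    label_unit = F.normalize(label_vecs, dim=1)
    similarity = text_unit @ label_unit.T  # shape: (n_texts, n_labels)
    return similarity.argmax(dim=1)


# Toy vectors standing in for real embeddings (assumption, not real data).
label_vecs = torch.tensor([[1.0, 0.0, 0.0],   # label 0
                           [0.0, 1.0, 0.0]])  # label 1
text_vecs = torch.tensor([[0.9, 0.1, 0.0],
                          [0.2, 0.8, 0.1]])

predicted = nearest_label(text_vecs, label_vecs)
```

Because the labels are only ever seen as vectors, a new label can be added at prediction time by embedding its name, with no retraining.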
How did they train the network?
Paper: https://arxiv.org/abs/1712.05972
1) They crawled the web and collected 300,000 unique SEO tags to use as labels.
2) They used word2vec pre-trained on Google News as word embeddings for both the sentences and the labels.
3) They tried three architectures:
3.1) Taking the mean of the word embeddings, concatenating it with the label embedding, and passing the result through a feed-forward (FF) network.
3.2) Running an LSTM over the word embeddings, concatenating its last hidden state with the label embedding, and passing the result through the FF network.
3.3) Combining the label embedding with the word embeddings before the LSTM, and passing the last hidden state through the FF network.
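Architecture 3.1 is the simplest of the three and can be sketched in a few lines of PyTorch. This is only an illustration of the description above, not the paper's code: the class name, the hidden size, and the sigmoid "related or not" output are assumptions, and random tensors stand in for the 300-dimensional word2vec vectors.

```python
import torch
import torch.nn as nn


class MeanEmbeddingRelatedness(nn.Module):
    """Hypothetical sketch: mean of word vectors + label vector -> FF network."""

    def __init__(self, embed_dim=300, hidden_dim=128):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, word_vecs, label_vec):
        # word_vecs: (batch, seq_len, embed_dim); label_vec: (batch, embed_dim)
        sentence_vec = word_vecs.mean(dim=1)
        joined = torch.cat([sentence_vec, label_vec], dim=1)
        # One score per (sentence, label) pair, squashed into [0, 1].
        return torch.sigmoid(self.ff(joined)).squeeze(1)


model = MeanEmbeddingRelatedness(embed_dim=300)
word_vecs = torch.randn(4, 12, 300)  # 4 sentences of 12 words each (random stand-ins)
label_vec = torch.randn(4, 300)      # one candidate label per sentence
scores = model(word_vecs, label_vec)
```

Since the label enters the network as an embedding rather than as an output index, the same trained network can score labels it never saw during training.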
This post walks through a simple spam classifier built with PyTorch. It shows how a model trained on English messages can be used to detect spam messages in French.
I used the SMS Spam Collection dataset from Kaggle (https://www.kaggle.com/uciml/sms-spam-collection-dataset), which has two columns: one that tells us whether a message is spam or not, and another that holds the message itself. To test in another language, I translated the message in every row into French and stored it in a new column. So the dataset looks like this:
I then wrote a simple PyTorch script that detects spam messages. Let us look at the different functions and their purpose.
Function 1:
The first step is to create a simple classifier with PyTorch. We pass in the size of the embedding, the number of labels (spam or not spam), and the dropout rate we need.
The classifier has two simple layers: the embeddings first go through a dropout layer and then through a fully connected layer. The class returns both the raw output tensor and its softmax, which gives the probability of each of the two labels.
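A minimal sketch of such a classifier is shown here. The class name, the default dropout of 0.1, and the embedding size of 768 (DistilBERT's hidden size) are assumptions made for illustration; the full code linked at the end of the post is the reference.

```python
import torch
import torch.nn as nn


class SpamClassifier(nn.Module):
    """Dropout followed by one fully connected layer over sentence embeddings."""

    def __init__(self, embedding_dim, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(embedding_dim, num_labels)

    def forward(self, embeddings):
        logits = self.fc(self.dropout(embeddings))
        probs = torch.softmax(logits, dim=1)  # probabilities of the two labels
        return logits, probs


model = SpamClassifier(embedding_dim=768)
model.eval()  # disables dropout for inference
logits, probs = model(torch.randn(5, 768))  # 5 random stand-in embeddings
```

Returning the raw logits alongside the probabilities is convenient because PyTorch's `nn.CrossEntropyLoss` expects logits, while the probabilities are easier to read at prediction time.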
Function 2:
We create a standard mini-batcher. The model can be trained in mini-batch, stochastic (one example at a time), or full-batch mode. I tend to use mini-batches most of the time, as in my experience they have a regularising effect and tend to give a more accurate model.
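One way to write such a batcher is as a small generator. The function name `minibatches` and its parameters are assumptions for this sketch, and the tensors are toy stand-ins for embeddings and labels.

```python
import torch


def minibatches(features, labels, batch_size=32, shuffle=True):
    """Yield (features, labels) chunks of at most batch_size rows."""
    n = features.shape[0]
    # A fresh random order on every call, so each epoch sees different batches.
    order = torch.randperm(n) if shuffle else torch.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], labels[idx]


features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
batch_sizes = [len(y) for _, y in minibatches(features, labels, batch_size=4)]
```

With this helper, `batch_size=1` corresponds to the stochastic method and `batch_size=len(features)` to the full-batch method, so the same function covers all three modes.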
Function 3:
We create a simple label encoder that transforms the labels (currently spam and ham) into integers. We then split the data into train and test sets. We do this twice because the English column and the French column need the same split: using the same test set in both languages lets us compare the accuracy in the two cases.
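The encoding and the two matching splits might look like the sketch below, assuming scikit-learn's `LabelEncoder` and `train_test_split`. The six messages are invented toy examples, not rows from the Kaggle dataset, and the variable names are hypothetical. Passing the same `random_state` to both calls is what keeps the English and French splits aligned.

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Invented toy rows (assumption): label, English message, French translation.
labels = ["spam", "ham", "spam", "ham", "spam", "ham"]
english = ["Win a free prize now", "Are we having lunch together today?",
           "Claim your cash reward", "See you tonight",
           "Urgent offer, reply to win", "Can you call me later?"]
french = ["Gagnez un prix gratuit maintenant", "On déjeune ensemble aujourd'hui ?",
          "Réclamez votre récompense en espèces", "À ce soir",
          "Offre urgente, répondez pour gagner", "Peux-tu m'appeler plus tard ?"]

encoder = LabelEncoder()
y = encoder.fit_transform(labels)  # classes are sorted, so ham -> 0, spam -> 1

# Same test_size and random_state in both calls -> identical row selection.
en_train, en_test, y_train_en, y_test_en = train_test_split(
    english, y, test_size=0.33, random_state=42)
fr_train, fr_test, y_train_fr, y_test_fr = train_test_split(
    french, y, test_size=0.33, random_state=42)
```

An alternative with the same effect is a single call, `train_test_split(english, french, y, ...)`, which splits all three sequences with one shuffle.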
We make use of Sentence Transformers and the quora-distilbert-multilingual model from Huggingface, a multilingual model trained on Quora question data that maps sentences from many languages into a shared embedding space. We then encode both the train and the test sets with this transformer.
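The encoding step can be sketched with the `sentence-transformers` package, whose `SentenceTransformer.encode` method turns a list of strings into a matrix of embeddings. The helper names `load_multilingual_encoder` and `embed_sentences` are hypothetical, and the model is loaded lazily because the first call downloads its weights.

```python
import numpy as np
import torch


def load_multilingual_encoder(model_name="quora-distilbert-multilingual"):
    """Load the sentence encoder (downloads the weights on first use)."""
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer(model_name)


def embed_sentences(encoder, sentences, batch_size=32):
    """Encode sentences into a float32 tensor of shape (n_sentences, dim)."""
    vectors = encoder.encode(list(sentences), batch_size=batch_size,
                             show_progress_bar=False)
    return torch.as_tensor(np.asarray(vectors), dtype=torch.float32)


# Example usage (commented out because it needs a network connection):
# encoder = load_multilingual_encoder()
# train_vecs = embed_sentences(encoder, ["Win a free prize now"])
# test_vecs = embed_sentences(encoder, ["Gagnez un prix gratuit maintenant"])
```

The same encoder is used for the English and the French text, which is what lets a classifier trained on English embeddings be applied to French ones.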
Function 4:
We run the batcher for 10 epochs, training the model with backpropagation and the Adam optimiser. Once the model is trained, we also evaluate it on the French test data.
Although there are some signs of overfitting, the model comfortably reaches an accuracy in the low 80s. The results are as follows.
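A compact sketch of such a training loop follows. To keep it self-contained it trains on random synthetic vectors instead of real sentence embeddings, and it uses an `nn.Sequential` stand-in for the classifier from Function 1. The cross-entropy loss, learning rate, batch size, and embedding size are assumptions, not values taken from the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic stand-ins for sentence embeddings (assumption, not real data).
dim = 16
direction = torch.randn(dim)


def synthetic_split(n):
    x = torch.randn(n, dim)
    y = (x @ direction > 0).long()  # label depends on one hidden direction
    return x, y


train_x, train_y = synthetic_split(400)  # plays the role of the English train set
test_x, test_y = synthetic_split(100)    # plays the role of the French test set

model = nn.Sequential(nn.Dropout(0.1), nn.Linear(dim, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()  # expects raw logits, not softmax output

for epoch in range(10):
    model.train()
    order = torch.randperm(len(train_x))
    for start in range(0, len(train_x), 32):
        idx = order[start:start + 32]
        optimizer.zero_grad()
        loss = loss_fn(model(train_x[idx]), train_y[idx])
        loss.backward()   # backpropagation
        optimizer.step()  # Adam update

model.eval()  # switch off dropout before evaluating
with torch.no_grad():
    predictions = model(test_x).argmax(dim=1)
    accuracy = (predictions == test_y).float().mean().item()
```

In the real setting, `train_x` would hold the English train embeddings and `test_x` the French test embeddings, so the accuracy measures how well the classifier transfers across languages.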
These zero-shot encoding techniques are very handy when a model trained on a single language has to handle comments, ratings, or chatbot conversations written in multiple languages. There are also models that can handle Hinglish or other mixed-language text with accuracies in the 70s. This is a simple yet powerful approach that can save you time.
I would love to know how you plan to use this in real-life applications and in your organisations, so let me know in the comments. Also feel free to post interesting models related to Zero-Shot Learning in the comments.
Regards and Thanks,
Vigneshwar
Please find the full code at : https://github.com/Vickyilango/SpamDetection-ZeroShotLearning.git
References: