Deep Text Classifier
Problem Description
We want to build a classifier for text input. For example, we may want to classify a company's industry from its description, or classify its IPO status from the same description.
General Approach
We build deep neural networks that solve all such problems in a uniform way. The traditional approach is to hire a task-specific expert to hand-design useful features, for example checking whether the text contains both the words "Internet" and "High-tech", and then to classify based on those features. A deep neural network instead extracts the features automatically and, most importantly, can achieve high test accuracy. The trade-off is that the features it learns are not human-interpretable.
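For contrast, a hand-designed feature of the kind described above might look like the minimal sketch below. The function names and the label "Technology" are hypothetical, chosen only to illustrate the traditional approach; the deep models in this repository learn their features instead of relying on rules like this.

```python
import re


def contains_all_keywords(text, keywords):
    """Hand-designed feature: True if every keyword appears as a whole word (case-insensitive)."""
    words = set(re.findall(r"[a-z0-9-]+", text.lower()))
    return all(k.lower() in words for k in keywords)


def rule_based_industry(description):
    """Hypothetical expert rule that maps one hand-crafted feature to a label."""
    if contains_all_keywords(description, ["Internet", "High-tech"]):
        return "Technology"
    return "Other"


print(rule_based_industry("An Internet company in the high-tech sector."))  # Technology
print(rule_based_industry("A regional bakery chain."))  # Other
```

Every new signal needs another rule written by hand, which is exactly the effort the deep models avoid.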
About the Deep Models
There are two broad categories of deep neural networks: convolutional neural networks (CNNs) and recurrent neural networks (RNNs). CNNs are best known for image classification tasks, although 1D convolutions over word sequences also work for text (hence the "1DConvolution" model described below). RNNs are designed for sequential data (e.g. language, video) and are a natural fit for text classification.
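Both families are available as Keras layers. The sketch below shows one plausible way to build a 1D-convolution classifier and an LSTM classifier over integer-encoded word sequences, assuming TensorFlow 2.x with its bundled Keras. The function names, layer sizes and vocabulary size are illustrative assumptions, not the exact architectures used in this repository's scripts.

```python
import numpy as np
from tensorflow import keras


def build_conv1d_classifier(vocab_size, num_classes, embed_dim=64):
    """CNN sketch: embed words, slide 1D filters over the sequence, max-pool over time."""
    model = keras.Sequential([
        keras.layers.Embedding(vocab_size, embed_dim),
        keras.layers.Conv1D(128, 5, activation="relu"),  # 128 filters, each spanning 5 consecutive words
        keras.layers.GlobalMaxPooling1D(),               # keep the strongest response of each filter
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",  # labels are integer class ids
                  metrics=["accuracy"])
    return model


def build_lstm_classifier(vocab_size, num_classes, embed_dim=64):
    """RNN sketch: embed words, read them in order with an LSTM, classify its final state."""
    model = keras.Sequential([
        keras.layers.Embedding(vocab_size, embed_dim),
        keras.layers.LSTM(64),
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


# Usage: a batch of 2 sequences of 50 word ids each, drawn from a 1,000-word vocabulary.
batch = np.random.randint(1, 1000, size=(2, 50))
probs = build_lstm_classifier(vocab_size=1000, num_classes=3)(batch)
print(probs.shape)  # (2, 3)
```

The two builders differ only in the middle layers: the convolution looks for local word patterns anywhere in the text, while the LSTM reads the words in order and keeps a running state.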
Major Package Dependencies
- TensorFlow https://www.tensorflow.org/
- NumPy http://www.numpy.org/
- Keras https://keras.io/
How to Run the Code
The code contains two parts: Data Preprocessing and Model Training/Prediction.
Data Preprocessing (preprocessing.py) : this step converts a text-based "XXX.txt" input file into a pickle file of numerical values that the rest of the code loads for training and prediction.
- Step 1 : modify the target file name in "main()"
# don't add the ".txt" extension
file_name = 'ThicketDefCodingTestProcessed'
- Step 2 : specify the expected number of columns in your target file in "main()"
# expected number of columns, in case we have "None" in the table
expected_columns = 5
- Step 3 : specify the indices of the text and the label in "prepare_imdb_structure(file_name, expected_columns)"
# the index of the label in the tokens
label_index = 1
# the index of the text in the tokens
content_index = 4
- Step 4 : run the code
python preprocessing.py
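The preprocessing steps above can be pictured with the standard-library sketch below. It assumes the input is tab-separated with one record per line, that rows with the wrong number of columns (for example, because of missing "None" fields) are skipped, and that each text becomes a list of word indices with index 0 reserved for padding. The function names and the pickle layout are hypothetical and may differ from what preprocessing.py actually writes.

```python
import pickle
import re


def load_rows(path, expected_columns, label_index, content_index):
    """Read (label, text) pairs from an assumed tab-separated file, skipping malformed rows."""
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            tokens = line.rstrip("\n").split("\t")
            if len(tokens) != expected_columns:
                continue  # e.g. a row with missing fields
            pairs.append((tokens[label_index], tokens[content_index]))
    return pairs


def encode(pairs):
    """Map each word and each label to an integer id; word id 0 is reserved for padding."""
    word_ids, label_ids = {}, {}
    sequences, labels = [], []
    for label, text in pairs:
        words = re.findall(r"[a-z0-9']+", text.lower())
        sequences.append([word_ids.setdefault(w, len(word_ids) + 1) for w in words])
        labels.append(label_ids.setdefault(label, len(label_ids)))
    return sequences, labels, word_ids, label_ids


def preprocess(file_name, expected_columns=5, label_index=1, content_index=4):
    """Hypothetical pipeline: <file_name>.txt in, <file_name>.pkl out."""
    pairs = load_rows(file_name + ".txt", expected_columns, label_index, content_index)
    data = encode(pairs)
    with open(file_name + ".pkl", "wb") as f:
        pickle.dump(data, f)
    return data
```

Under these assumptions, preprocess('ThicketDefCodingTestProcessed') would write ThicketDefCodingTestProcessed.pkl next to the input file.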
Model Training/Prediction (classification_MMM_LLL.py) : this is where the deep neural network lives. "MMM" stands for the model; currently I have "1DConvolution", "2DConvolution" and "LSTM". "LLL" stands for the name of the label. Note that the same model can be used as-is to predict different things from the same text inputs. For example, "classification_LSTM_indu.py" is an LSTM model that predicts the industry from the descriptions, and "classification_LSTM_ipo.py" is an LSTM model that predicts the IPO status from the same descriptions, so name your file accordingly.
Whatever the model, this Python file always loads the pickle file you generated in the previous step and trains the neural network. At the end, the trained network predicts on your test examples (examples it did not see during training) and prints the accuracy. To run this part:
python classification_LSTM_indu.py
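The data-handling side of a training script (load the pickle, pad the sequences to a fixed length, hold out test examples, report accuracy) can be sketched as below. It assumes the pickle stores the word-id sequences and integer labels as its first two items; the helper names, the 20% test fraction and the maximum length of 100 are hypothetical choices, not taken from the repository.

```python
import pickle
import random


def pad(sequences, max_len):
    """Truncate or right-pad each word-id sequence with 0 to exactly max_len ids."""
    return [seq[:max_len] + [0] * (max_len - len(seq[:max_len])) for seq in sequences]


def split(x, y, test_fraction=0.2, seed=0):
    """Shuffle once, then hold out the last test_fraction of examples for testing."""
    idx = list(range(len(x)))
    random.Random(seed).shuffle(idx)
    cut = int(len(idx) * (1 - test_fraction))
    train, test = idx[:cut], idx[cut:]
    return ([x[i] for i in train], [y[i] for i in train],
            [x[i] for i in test], [y[i] for i in test])


def accuracy(predicted, actual):
    """Fraction of test examples whose predicted label equals the true label."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


def load_dataset(pickle_path, max_len=100):
    """Assumes the pickle stores (sequences, labels, ...) as its first two items."""
    with open(pickle_path, "rb") as f:
        data = pickle.load(f)
    sequences, labels = data[0], data[1]
    return split(pad(sequences, max_len), labels)
```

With a compiled Keras model, training and evaluation would then be model.fit(x_train, y_train, ...) followed by model.evaluate(x_test, y_test), after converting the lists to NumPy arrays. Keras's own pad_sequences pads at the front by default, which is often preferred for LSTMs; right-padding is used here only to keep the sketch short.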
Note that data preprocessing usually needs to be run only once. The saved pickle file is a machine-friendly serialized copy of the processed data that loads very quickly.
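Because preprocessing only needs to run once, a script can reuse the pickle when it already exists and rebuild it otherwise. The helper below is a hypothetical sketch of that pattern; build_fn stands for whatever function produces the processed data.

```python
import os
import pickle


def load_or_build(pickle_path, build_fn):
    """Return cached data from pickle_path, or call build_fn() once and cache its result."""
    if os.path.exists(pickle_path):
        with open(pickle_path, "rb") as f:
            return pickle.load(f)
    data = build_fn()
    with open(pickle_path, "wb") as f:
        pickle.dump(data, f)
    return data
```

Delete the pickle file to force a rebuild after changing the input or the column settings. Since unpickling can execute arbitrary code, only load pickle files you created yourself.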