Comparing Support Vector Machines and Decision Trees for Text Classification

Over the past decade, it's become easier and easier for non-experts to use machine learning to solve real-world problems.

However, we're still far from being able to press a "do machine learning" button without any underlying knowledge. Although it's easy to use a machine learning framework like scikit-learn to train and use a model in a few lines of Python, you still need to choose which of the multitude of different algorithms to use.

In this tutorial, we'll compare two popular machine learning algorithms for text classification: Support Vector Machines and Decision Trees. To follow along, you should have basic knowledge of Python and be able to install third-party Python libraries (with, for example, pip or conda). We'll be using scikit-learn, a Python library that includes an implementation and standard interface for many different machine learning algorithms.

Examples of text classification problems

Before we get started, let's take a quick step back and look at text classification more generally.

The main advantage of machine learning is that the same code can be used to solve many different problems. Historically, it was common practice to build large, customised, domain-specific rulesets, which required domain experts, programmers, and hundreds of thousands of lines of code. Now, entire categories of problems can be solved by the same code, simply by swapping out the data that is used for learning.

All of the following issues can be thought of as text classification problems and can be solved by using the methods outlined in this tutorial.

  • Spam detection: Spam went from a huge, expensive problem that plagued nearly everyone who used email to something that is almost forgotten about, thanks to advances in machine learning and text classification. Spam checkers look at the full text of incoming emails and automatically assign one of two labels: "Spam" or "Not spam" (also often referred to as "spam" and "ham").
  • Sentiment analysis: Deciding if a specific piece of language is, overall, "happy" or "sad" (or sometimes more fine-grained labels, such as "angry" or "neutral") is an interesting problem that is used to do everything from automatic stock trading, to managing and monitoring the reputation of a specific brand.
  • Authorship attribution: Anyone who strings words together in any context has their own unique style. Sometimes people try to stay anonymous (for example, when writing ransom notes) or pretend to be someone else (for example, in cases of plagiarism and artistic forgery). Automatically working out whether a specific person wrote a given piece of text is a kind of text classification that can be applied to fields such as education and forensic linguistics.
  • Content moderation: If you've ever used the Internet, you probably already know how nasty comment sections can get, especially when the discussion turns political. It would be impossible for humans to moderate all of this, removing hate speech and other undesirable content, without the assistance of text classification algorithms that can automatically decide whether a specific comment should be published or blocked.
  • Customer support and triage: Customers expect fast responses to their troubles, and companies can give them this by automatically routing any incoming messages to the right people. How urgent is the problem? How difficult? Can it be resolved automatically? Traditionally, support desks operate in "Tiers", where junior support staff deal with the incoming firehose of problems, resolve the easier ones, and escalate the difficult ones up one tier. These days, they rely more and more on text classification to send incoming messages to the correct tier automatically so the issues can be resolved more quickly.

Traditionally, each of these problems would have needed thousands of hours of work to build and maintain classifiers that could assign the correct labels based on a problem-specific set of rules. With Machine Learning, we can use one algorithm to solve all of the above problems (and more!) by simply having a different dataset for each problem with which we can "train" our machines.

A practical example

There is a public dataset of support tickets available on GitHub. You can read the information in the repository to see how Microsoft and Endava collaborated to solve a more challenging classification problem using more complicated tooling. In this article, we'll be using a simpler version of the dataset with a simpler goal.

The dataset has been anonymised and stripped of any sensitive information, so some of the support tickets read a bit strangely. Here's a readable example of a ticket from the dataset:

problem with laptop monitor dear please open ticket for for problem regarding broken laptop display according forgot pencil near keyboard when tried close lid seem be physically cracked but there visible trace spot where happened attaching pictures able picture only when external monitor connected laptop otherwise possible or least manage work out we issue him laptop contingency while his laptop en maintenance thank you engineer ext en

However, due to the anonymisation process, some of them are garbled. For example:

hi change possible myself log thanks friday december pm re approved friday december pm dear owner kindly looking forward hearing thank kind regards analyst ext history reference history additional comments visible update task task assignment assignee o wner closed complete work notes description hi please these persons change name more ref msg

This makes the task more difficult, but let's see how well we can do!

Each ticket also has an associated "urgency score" between 0 and 3, where 0 is "very urgent" and 3 is "not urgent".

It would be useful if a machine could guess how urgent a ticket is based on its description, so that the most urgent tickets can be resolved first. We'll do this by using some of the tickets, along with their labels, as a "training set" to teach our algorithms how to discriminate between urgent and non-urgent tickets. The remaining tickets become a "test set": we hold back their labels, ask the trained algorithm to guess, and then compare its guesses against the labels we already have to see how well the machine did.

Our dataset

The original dataset is unbalanced: it contains far more non-urgent tickets than urgent ones. Although this is a common problem, and one that needs to be thought about when using machine learning in general, we'll take the easy way out for now and resample our dataset by taking an equally sized subset of tickets of each urgency level.
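
For illustration, here's a minimal sketch of how such a balanced subset could be produced. It assumes the raw data is in a pandas DataFrame with hypothetical "body" and "urgency" columns, so treat it as a rough outline rather than the exact preprocessing used for this dataset (the balanced files are provided below, so you don't need to run this):

import pandas as pd

def balance(df, label_col="urgency", n_per_class=1500, seed=1337):
    """Randomly sample the same number of rows from each class."""
    return (
        df.groupby(label_col, group_keys=False)
          .apply(lambda group: group.sample(n=n_per_class, random_state=seed))
          .reset_index(drop=True)
    )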

Also, the original dataset contains many different attributes. We'll ignore everything except the body of the ticket (the text describing the problem), and the "urgency" attribute (describing how urgent the ticket is).

Our simpler dataset contains 6000 tickets, with 1500 of each urgency level, labeled 0, 1, 2, and 3. You can find the tickets in the tickets.txt file, one ticket per line, and the labels in the labels_4.txt file, again one label per line and in the same order as the tickets. There is also an even simpler set of labels, labels_2.txt, where the tickets have been grouped into "urgent" (labels 0 and 1) and "not_urgent" (labels 2 and 3).

Building a text classifier in Python

Create a file and add the following code. If you're comfortable using Jupyter Notebook, it's helpful to use that to view the intermediate results as we build up our solution.

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import train_test_split

from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

We need two things from scikit-learn: a vectorizer, which transforms our text into numerical representations that are easier for a machine to work with, and a classifier, which is the statistical model that we'll train on our dataset. Each of these is a fairly complicated concept that we won't cover in depth here; for a beginner's guide to vectorization and classification, see my Introduction to Machine Learning article. In the code above, we've imported two different classifiers (a decision tree and a support vector machine) so that we can compare their results, and two different vectorizers (a simple "Count" vectorizer and a more complicated "TF-IDF", or Term Frequency-Inverse Document Frequency, vectorizer), also to compare results.
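
To make the vectorization step more concrete, here's a tiny, self-contained example using two made-up sentences (not from our dataset):

from sklearn.feature_extraction.text import CountVectorizer

toy_texts = ["my laptop screen is broken", "my laptop is very slow"]
toy_vectorizer = CountVectorizer()
toy_vectors = toy_vectorizer.fit_transform(toy_texts)
print(toy_vectorizer.vocabulary_)  # each word in the texts is assigned a column index
print(toy_vectors.toarray())       # one row of word counts per sentence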

We also imported a couple of helper functions, such as train_test_split and classification_report. The former is used to split our data into a training and a test set, while the latter inspects the results of our models once they are trained.

Loading and splitting our dataset

To load the dataset from the .txt files, add the following code below your imports:

with open("tickets.txt") as f:
    tickets = f.read().strip().split("\n")

with open("labels_4.txt") as f:
    labels = f.read().strip().split("\n")
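
If you want to check that the two files line up as expected, you can print their lengths and the first label/ticket pair:

print(len(tickets), len(labels))   # both should print 6000
print(labels[0], tickets[0][:60])  # the first label and the start of its ticket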

We now have two arrays of 6000 elements each. The first element in the tickets array corresponds to the first item in the labels array, and so on. We want to split these into training and test datasets, using the former to train our classifier and the latter to evaluate the results.

This can be done using the train_test_split helper that we imported above by adding the following code to our file:

X_train, X_test, y_train, y_test = train_test_split(tickets, labels, test_size=0.1, random_state=1337)

We now have 5400 texts (or 90%) in X_train and the other 600 in X_test. We used test_size to define how much data we want to hold back for testing and defined random_state to make sure we get reproducible results.

Training Support Vector Machines (SVMs)

The next steps are to vectorize our dataset, train the classifier, and check the predicted results against our actual results. This can be done by adding the following code:

vectorizer = CountVectorizer()
svm = LinearSVC()
X_train = vectorizer.fit_transform(X_train)
X_test = vectorizer.transform(X_test)
_ = svm.fit(X_train, y_train)
y_pred = svm.predict(X_test)
print(classification_report(y_test, y_pred))

In the first two lines, we created a vectorizer and a classifier. We then called fit_transform on X_train, which assigns a numerical index to each word in our training set and transforms the texts into sparse Bag of Words vectors, in which each text is represented by counting how often each word appears. We used the same vectorizer to transform our test set into vectors, trained our classifier on the training vectors, and predicted labels for the test set. Finally, we printed a report on how well our classifier did.
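
If you're curious about what the vectorizer produced, you can inspect the shape of the resulting matrices (the exact number of columns depends on the dataset's vocabulary):

print(X_train.shape)  # (number of training tickets, number of unique words seen in training)
print(X_test.shape)   # same number of columns, since we reused the fitted vectorizer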

There is some randomness involved in the training process, but you should see a report that's similar to this one:

              precision    recall  f1-score   support

           0       0.75      0.79      0.77       159
           1       0.53      0.52      0.52       147
           2       0.56      0.54      0.55       154
           3       0.96      0.95      0.95       140

   micro avg       0.70      0.70      0.70       600
   macro avg       0.70      0.70      0.70       600
weighted avg       0.69      0.70      0.70       600

If our classifier were simply guessing the urgency of each support ticket at random, we would expect it to be right about 25% of the time through sheer luck. However, the report shows that it is actually right 70% of the time! The single most useful aggregate number to look at is the f1-score column in the weighted avg row.

We can see that it is easier to identify support tickets labeled 0 (f1 score of 77%) and 3 (f1 score of 95%), while the 1s and 2s, which are tickets somewhere between "very urgent" and "not urgent", are harder to discriminate. This makes intuitive sense as the grey area between "not really urgent but sort of" and "sort of urgent but not really" is larger than that between "urgent" and "not urgent".
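
If you just want those aggregate numbers on their own, rather than the full report, scikit-learn exposes them directly:

from sklearn.metrics import accuracy_score, f1_score

print(accuracy_score(y_test, y_pred))                # overall fraction of correct guesses
print(f1_score(y_test, y_pred, average="weighted"))  # the weighted avg f1-score from the report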

Let's take a look at the Confusion Matrix by adding the following line:

print(confusion_matrix(y_test, y_pred))

Which should look something like:

[[126  18  14   1]
 [ 20  75  50   2]
 [ 19  49  83   3]
 [  3   1   3 133]]

Confusion matrices can be quite confusing to interpret. In the top left-hand corner, we can see the 126 0s in our test set that we predicted correctly, and in the bottom right, the 133 3s that we also got right. In the middle four cells, we can see that the classifier often mixed up the 1s and 2s.
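
One small trick that makes the matrix easier to read, if you have pandas installed, is to label the rows (true classes) and columns (predicted classes):

import pandas as pd

cm = confusion_matrix(y_test, y_pred, labels=svm.classes_)
print(pd.DataFrame(cm, index=svm.classes_, columns=svm.classes_))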

The problem with our algorithm is that it is a "black box". We don't know how it works or what words it looked at to decide how urgent a specific ticket is. If it starts getting things wrong, we won't know why. Even worse, we won't know if our classifier generalises well or if it found a specific "trick" in this data set that allowed it to do well. This is similar to the famous urban legend about using machine learning to find army tanks.
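
We can't open the black box this way, but we can at least check that the score isn't a fluke of one particular train/test split. The cross_val_score helper we imported earlier (and haven't used yet) trains and evaluates a fresh model on several different splits; a stable score across folds is a reassuring, if limited, signal:

from sklearn.pipeline import make_pipeline

# Build a vectorizer + classifier pipeline so each fold fits its own vocabulary.
pipeline = make_pipeline(CountVectorizer(), LinearSVC())
scores = cross_val_score(pipeline, tickets, labels, cv=5)
print(scores)         # one accuracy score per fold
print(scores.mean())  # should be roughly in line with the report above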

Using a Decision Tree

Let's compare our results against a simpler and more interpretable classifier — the Decision Tree. Add the following code to your file:

dt = DecisionTreeClassifier()
dt.fit(X_train, y_train)
y_pred = dt.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

This does exactly what we did before, but using a Decision Tree as the classifier. My results after running this are:

              precision    recall  f1-score   support

           0       0.69      0.71      0.70       159
           1       0.50      0.43      0.46       147
           2       0.51      0.54      0.52       154
           3       0.94      0.97      0.95       140

   micro avg       0.66      0.66      0.66       600
   macro avg       0.66      0.66      0.66       600
weighted avg       0.65      0.66      0.65       600

[[113  16  27   3]
 [ 31  63  51   2]
 [ 20  47  83   4]
 [  0   1   3 136]]

We can see that these results are definitely worse than our Support Vector Machine's, but still a lot better than random guessing. The nice thing about decision trees is that we can see exactly how decisions are made: the final model is a simple tree of "True/False" decisions, and scikit-learn can export it so that we can visualise it.

If you're using Jupyter Notebook, you can visualise the Decision Tree by adding the following code:

from IPython.display import SVG
from graphviz import Source
from IPython.display import display
graph = Source(
    tree.export_graphviz(
        dt,
        out_file=None,
        feature_names=vectorizer.get_feature_names(),
        class_names=dt.classes_,
        filled=True)
)
display(SVG(graph.pipe(format='svg')))

This will display the decision tree directly in your notebook.

If you're not using Jupyter, you can create a "dot" file of the tree, which is a text-based representation, with this code:

tree.export_graphviz(
        dt,
        out_file="tree.dot",
        feature_names=vectorizer.get_feature_names(),
        class_names=dt.classes_,
        filled = True
)

This will create a tree.dot file in the same directory as your Python script. You can convert it into an image with the following shell command:

dot -Tpng tree.dot > output.png

If that doesn't work, you can copy the file contents to your clipboard and paste them into an online Graphviz visualiser.

The final tree is quite large and difficult to see. For the example visualisation below, I changed this line

dt = DecisionTreeClassifier()

to

dt = DecisionTreeClassifier(max_depth=4)

and re-ran the training and visualisation steps.

By doing this, our tree is only allowed to have four levels at most. While this decreases the model's overall accuracy, it's a good sanity check to see which keywords the model is using. For example, we can see towards the bottom right of the tree that "november" is an important word. This is probably a cause for concern, as it indicates that a lot of tickets look different from the others simply because they contain the word "november". This is likely a pattern specific to this dataset (support tickets submitted in a certain time period), rather than a rule that is generally applicable to all support tickets.
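
You can also surface dataset-specific words like this without rendering the whole tree, by listing the features the fitted tree considers most important:

import numpy as np

importances = dt.feature_importances_
feature_names = vectorizer.get_feature_names()  # get_feature_names_out() on newer scikit-learn versions
for index in np.argsort(importances)[::-1][:10]:
    print(feature_names[index], round(float(importances[index]), 3))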

Final insights

We looked at two different algorithms to automatically classify support tickets into different tiers of Urgency. We found that Support Vector Machines work better out of the box, but Decision Trees gave us more insight into how the model worked.

If you’re trying to build a practical solution to a real world problem, you’ll probably care more about the model accuracy than the interpretability. However, that doesn’t mean that the Decision Tree isn’t worth building – it’s great to take a quick look at the rules it produces to find patterns in your dataset that you might not have known about.

Decision Trees are great for their simplicity and interpretation, but they are more limited in their power to learn complicated rules and to scale to large data sets. Support Vector Machines are often more powerful and can scale to larger data sets, but they are also more complicated and it can be difficult to know what patterns they rely on.

If you're eager to explore this problem further, there's no shortage of directions to explore! For example, you could:

  • Use the larger (original) unbalanced dataset, instead of the smaller balanced sample that we used above.
  • Try to predict factors other than "urgency". Would it be useful to predict the business category, impact, and other features listed in the original dataset?
  • Take a look at how much better our algorithms perform on the even simpler labels. Swap out labels_4.txt for the labels_2.txt file to see if it's easier to predict when we have fewer classes.
  • Use the TfidfVectorizer with character ngrams instead of words and a larger ngram_range (see the sketch after this list). You should see performance improve to around 80%.
  • Find another dataset of support tickets and see if the rules that the classifier learnt from this dataset can be applied.
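
As a rough starting point for the character n-gram suggestion above, here is a sketch; the analyzer and ngram_range values are guesses to experiment with rather than tuned settings:

# Re-split from the raw text, since X_train above has already been turned into count vectors.
raw_train, raw_test, y_train, y_test = train_test_split(tickets, labels, test_size=0.1, random_state=1337)
char_vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 5))
char_svm = LinearSVC()
char_svm.fit(char_vectorizer.fit_transform(raw_train), y_train)
print(classification_report(y_test, char_svm.predict(char_vectorizer.transform(raw_test))))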

Of course, you can try some of the other problems we discussed at the start of this post, too. Can our model do sentiment analysis or content moderation? There are lots of text datasets that are publicly available, so find one that interests you and see whether the techniques we described here can solve other problems just as easily.

If you have questions or comments about this post, feel free to ping me on Twitter. You might also enjoy my other tutorials on Codementor.

Last updated on May 12, 2020