Spark & Python: MLlib Decision Trees

Published Jul 03, 2015. Last updated Feb 15, 2017.


My Spark & Python series of tutorials can be read individually, although they tell a more or less linear 'story' when followed in sequence. They all use the same dataset to solve a related set of tasks.

It is not the only way, but a good way of following these Spark tutorials is to first clone the GitHub repo and then start your own IPython notebook in pySpark mode. For example, if we have a standalone Spark installation running on our localhost with a maximum of 6 GB per node assigned to IPython:

MASTER="spark://" SPARK_EXECUTOR_MEMORY="6G" IPYTHON_OPTS="notebook --pylab inline" ~/spark-1.3.1-bin-hadoop2.6/bin/pyspark

Notice that the path to the pyspark command will depend on your specific installation. As a requirement, you need to have Spark installed on the same machine where you are going to start the IPython notebook server.

For more Spark options see here. As a general rule, options described in the form spark.executor.memory are passed as SPARK_EXECUTOR_MEMORY when calling IPython/pySpark.


We will be using datasets from the KDD Cup 1999.


The reference book for these and other Spark related topics is Learning Spark by Holden Karau, Andy Konwinski, Patrick Wendell, and Matei Zaharia.

The KDD Cup 1999 competition dataset is described in detail here.


In this tutorial we will use Spark's machine learning library MLlib to build a Decision Tree classifier for network attack detection. We will use the complete KDD Cup 1999 datasets to test Spark capabilities with large datasets.

Decision trees are a popular machine learning tool in part because they are easy to interpret, handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. In this notebook, we will first train a classification tree including every single predictor. Then we will use our results to perform model selection. Once we find out the most important predictors (the main splits in the tree), we will build a minimal tree using just three of them (the first two levels of the tree) in order to compare performance and accuracy.

For reference when reading the execution times below, at the time of running this tutorial our Spark cluster contained:

  • Eight nodes, with one of them acting as master and the rest as workers (for a total of 14 execution cores available).
  • Each node has 8 GB of RAM, with 6 GB being used on each node.
  • Each node has a 2.4 GHz Intel dual-core processor.

Getting the Data and Creating the RDD

As we said, this time we will use the complete dataset provided for the KDD Cup 1999, containing nearly five million network interactions. The file is provided as a Gzip file that we will download locally.

import urllib

f = urllib.urlretrieve ("", "")

data_file = "./"
raw_data = sc.textFile(data_file)

print "Train data size is {}".format(raw_data.count())
Train data size is 4898431

The KDD Cup 1999 also provides test data that we will load in a separate RDD.

ft = urllib.urlretrieve("", "corrected.gz")

test_data_file = "./corrected.gz"
test_raw_data = sc.textFile(test_data_file)

print "Test data size is {}".format(test_raw_data.count())
Test data size is 311029

Detecting network attacks using Decision Trees

In this section we will train a classification tree that, as we did with logistic regression, will predict whether a network interaction is normal or an attack.

Training a classification tree using MLlib requires some parameters:

  • Training data
  • Num classes
  • Categorical features info: a map from column to categorical variable arity. This is optional, although it should increase model accuracy. However, it has two requirements: first, we need to know the levels of our categorical variables in advance; second, we need to parse our data to convert those levels to integer values within the arity range.
  • Impurity metric
  • Tree maximum depth
  • Maximum number of bins

In the next section we will see how to obtain all the labels within a dataset and convert them to numerical factors.

Preparing the Data

As we said, in order to benefit from the ability of trees to work seamlessly with categorical variables, we need to convert them to numerical factors. But first we need to obtain all the possible levels. We will use set transformations on a CSV-parsed RDD.

from pyspark.mllib.regression import LabeledPoint
from numpy import array

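csv_data = raw_data.map(lambda x: x.split(","))
test_csv_data = test_raw_data.map(lambda x: x.split(","))

protocols = csv_data.map(lambda x: x[1]).distinct().collect()
services = csv_data.map(lambda x: x[2]).distinct().collect()
flags = csv_data.map(lambda x: x[3]).distinct().collect()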

And now we can use these Python lists in our create_labeled_point function. If a factor level is not in the training data, we assign a special level. Remember that we cannot use testing data to train our model, not even to obtain the factor levels. The testing data represents what is unknown to us in a real case.

def create_labeled_point(line_split):
    # keep the 41 features, leaving out the label in column 41
    clean_line_split = line_split[0:41]
    # convert protocol to numeric categorical variable
    try:
        clean_line_split[1] = protocols.index(clean_line_split[1])
    except ValueError:
        # level not seen in the training data
        clean_line_split[1] = len(protocols)
    # convert service to numeric categorical variable
    try:
        clean_line_split[2] = services.index(clean_line_split[2])
    except ValueError:
        clean_line_split[2] = len(services)
    # convert flag to numeric categorical variable
    try:
        clean_line_split[3] = flags.index(clean_line_split[3])
    except ValueError:
        clean_line_split[3] = len(flags)
    # convert label to binary label
    attack = 1.0
    if line_split[41]=='normal.':
        attack = 0.0
    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

training_data = csv_data.map(create_labeled_point)
test_data = test_csv_data.map(create_labeled_point)
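
One possible variation, not part of the original notebook: list.index scans the list on every call, so with millions of rows it may be cheaper to build a level-to-index dictionary once and fall back to a special index for unseen levels. The names build_level_index and encode_level, and the sample levels, are hypothetical.

```python
def build_level_index(levels):
    """Map each known level to its position in the list."""
    return {level: i for i, level in enumerate(levels)}


def encode_level(value, level_index):
    """Return the index of a known level, or len(level_index) for an unseen one."""
    return level_index.get(value, len(level_index))


# Hypothetical levels, standing in for the `protocols` list collected above.
sample_protocols = ["tcp", "udp", "icmp"]
protocol_index = build_level_index(sample_protocols)

print(encode_level("udp", protocol_index))  # known level
print(encode_level("gre", protocol_index))  # unseen level gets the extra index
```

The fallback index equals the number of known levels, which mirrors what the function above assigns when a level is missing from the training data.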

Training a Classifier

We are now ready to train our classification tree. We will keep the maxDepth value small. This will lead to lower accuracy, but we will obtain fewer splits, so later on we can interpret the tree more easily. In a production system we would try increasing this value in order to find better accuracy.

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from time import time

# Build the model
t0 = time()
tree_model = DecisionTree.trainClassifier(training_data, numClasses=2, 
                                          categoricalFeaturesInfo={1: len(protocols), 2: len(services), 3: len(flags)},
                                          impurity='gini', maxDepth=4, maxBins=100)
tt = time() - t0

print "Classifier trained in {} seconds".format(round(tt,3))
Classifier trained in 439.971 seconds
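
To make the effect of maxDepth concrete: a binary tree of depth d has at most 2^(d+1) - 1 nodes, of which at most 2^d - 1 are splits. The helper below is a hypothetical illustration, not part of MLlib; it assumes depth is counted so that a tree made of a single leaf has depth 0.

```python
def max_tree_size(max_depth):
    """Upper bounds on (total nodes, split nodes) for a binary tree of this depth."""
    total_nodes = 2 ** (max_depth + 1) - 1
    split_nodes = 2 ** max_depth - 1
    return total_nodes, split_nodes


for depth in (3, 4, 10):
    nodes, splits = max_tree_size(depth)
    print("maxDepth={}: at most {} nodes, {} splits".format(depth, nodes, splits))
```

With maxDepth=4 the bound is 31 nodes and 15 splits, which is still small enough to read by eye, while maxDepth=10 already allows more than two thousand nodes.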

Evaluating the Model

In order to measure the classification error on our test data, we use map on the test_data RDD and the model to predict each test point class.

predictions = tree_model.predict(test_data.map(lambda p: p.features))
labels_and_preds = test_data.map(lambda p: p.label).zip(predictions)

Classification results are returned in pairs, with the actual test label and the predicted one. These pairs are used to calculate the classification accuracy by using filter and count as follows.

t0 = time()
test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data.count())
tt = time() - t0

print "Prediction made in {} seconds. Test accuracy is {}".format(round(tt,3), round(test_accuracy,4))
Prediction made in 38.651 seconds. Test accuracy is 0.9155

NOTE: the zip transformation doesn't work properly with pySpark 1.2.1. It does in 1.3.
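
The filter-and-count step can be checked on a handful of hand-made pairs without a cluster. The sketch below is plain Python rather than Spark code, and the sample pairs are invented for illustration. It also avoids the `lambda (v, p):` tuple-unpacking syntax, which is only valid in Python 2.

```python
# Invented (label, prediction) pairs, standing in for collected labels_and_preds.
sample_pairs = [(1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (0.0, 0.0), (1.0, 1.0)]


def accuracy(pairs):
    """Fraction of pairs whose label equals the prediction."""
    correct = sum(1 for label, pred in pairs if label == pred)
    return correct / float(len(pairs))


print("Accuracy is {}".format(accuracy(sample_pairs)))
```

On an RDD, the same predicate can be written as `lambda vp: vp[0] == vp[1]`, which works under both Python 2 and Python 3.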

Interpreting the Model

Understanding our tree splits is a great exercise for explaining our classification labels in terms of predictors and the values they take. Using the toDebugString method of our tree model we can obtain a lot of information regarding splits, nodes, etc.

print "Learned classification tree model:"
print tree_model.toDebugString()

Learned classification tree model:

DecisionTreeModel classifier of depth 4 with 27 nodes
  If (feature 22 <= 89.0)
   If (feature 3 in {2.0,3.0,4.0,7.0,9.0,10.0})
    If (feature 36 <= 0.43)
     If (feature 28 <= 0.19)
      Predict: 1.0
     Else (feature 28 > 0.19)
      Predict: 0.0
    Else (feature 36 > 0.43)
     If (feature 2 in {0.0,3.0,15.0,26.0,27.0,36.0,42.0,58.0,67.0})
      Predict: 0.0
     Else (feature 2 not in {0.0,3.0,15.0,26.0,27.0,36.0,42.0,58.0,67.0})
      Predict: 1.0
   Else (feature 3 not in {2.0,3.0,4.0,7.0,9.0,10.0})
    If (feature 2 in {50.0,51.0})
     Predict: 0.0
    Else (feature 2 not in {50.0,51.0})
     If (feature 32 <= 168.0)
      Predict: 1.0
     Else (feature 32 > 168.0)
      Predict: 0.0
  Else (feature 22 > 89.0)
   If (feature 5 <= 0.0)
    If (feature 11 <= 0.0)
     If (feature 31 <= 253.0)
      Predict: 1.0
     Else (feature 31 > 253.0)
      Predict: 1.0
    Else (feature 11 > 0.0)
     If (feature 2 in {12.0})
      Predict: 0.0
     Else (feature 2 not in {12.0})
      Predict: 1.0
   Else (feature 5 > 0.0)
    If (feature 29 <= 0.08)
     If (feature 2 in {3.0,4.0,26.0,36.0,42.0,58.0,68.0})
      Predict: 0.0
     Else (feature 2 not in {3.0,4.0,26.0,36.0,42.0,58.0,68.0})
      Predict: 1.0
    Else (feature 29 > 0.08)
     Predict: 1.0
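
The debug string refers to features by position. Assuming the standard KDD Cup 1999 column order (the list below follows the dataset's published column names rather than anything listed in this tutorial, so check it against your copy of the data), a small lookup helps when reading the tree. The names KDD_FEATURES and feature_name are hypothetical.

```python
# Names of the 41 KDD Cup 1999 feature columns, in file order (assumed).
KDD_FEATURES = [
    "duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes",
    "land", "wrong_fragment", "urgent", "hot", "num_failed_logins",
    "logged_in", "num_compromised", "root_shell", "su_attempted", "num_root",
    "num_file_creations", "num_shells", "num_access_files",
    "num_outbound_cmds", "is_host_login", "is_guest_login", "count",
    "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
    "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate",
    "dst_host_count", "dst_host_srv_count", "dst_host_same_srv_rate",
    "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate", "dst_host_serror_rate",
    "dst_host_srv_serror_rate", "dst_host_rerror_rate",
    "dst_host_srv_rerror_rate",
]


def feature_name(index):
    """Translate a feature position from the debug string into a column name."""
    return KDD_FEATURES[index]


for i in (22, 3, 5, 11):
    print("feature {} is {}".format(i, feature_name(i)))
```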

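For example, a network interaction with the following features (see description here) will be classified as an attack by our model (the thresholds and level indices quoted here appear to come from a different training run than the tree printed above, so they do not match it exactly):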

  • count, the number of connections to the same host as the current connection in the past two seconds, is greater than 32.
  • dst_bytes, the number of data bytes from destination to source, is 0.
  • service is neither level 0 nor 52.
  • logged_in is false.

From our services list we know that:

print "Service 0 is {}".format(services[0])
print "Service 52 is {}".format(services[52])
Service 0 is urp_i  
Service 52 is tftp_u  

So we can characterise network interactions with more than 32 connections to the same server in the last 2 seconds, transferring zero bytes from destination to source, where the service is neither urp_i nor tftp_u, and not logged in, as network attacks. A similar approach can be used for each terminal node of the tree.
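
The rule just described can be written as a plain Python predicate. This is only a hand-written restatement of that one path, using the threshold and service names quoted in the text; the function name, the dictionary keys, and the sample connections are hypothetical, and this is not how MLlib evaluates the tree.

```python
def matches_attack_rule(conn):
    """True when a connection satisfies the single tree path described in the text."""
    return (conn["count"] > 32
            and conn["dst_bytes"] == 0
            and conn["service"] not in ("urp_i", "tftp_u")
            and not conn["logged_in"])


# Invented connections for illustration.
flood = {"count": 120, "dst_bytes": 0, "service": "private", "logged_in": False}
browse = {"count": 3, "dst_bytes": 5120, "service": "http", "logged_in": True}

print(matches_attack_rule(flood))   # satisfies every condition
print(matches_attack_rule(browse))  # fails the count condition
```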

We can see that count is the first node split in the tree. Remember that each partition is chosen greedily by selecting the best split from a set of possible splits, in order to maximize the information gain at a tree node (see more here). At the second level we find the variables flag (normal or error status of the connection) and dst_bytes (the number of data bytes from destination to source), and so on.
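
As a rough illustration of what "best split" means with impurity='gini', the sketch below computes the Gini impurity of a node and the impurity reduction achieved by a candidate split. It is a plain Python toy with invented labels, not MLlib's implementation, and the function names are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = float(len(labels))
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def split_gain(parent, left, right):
    """Impurity of the parent minus the size-weighted impurity of the children."""
    n = float(len(parent))
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted


# Invented labels: 1 = attack, 0 = normal.
parent = [1, 1, 1, 1, 0, 0, 0, 0]
good_split = ([1, 1, 1, 1], [0, 0, 0, 0])  # separates the classes perfectly
poor_split = ([1, 1, 0, 0], [1, 1, 0, 0])  # leaves both children mixed

print(split_gain(parent, *good_split))
print(split_gain(parent, *poor_split))
```

A greedy learner evaluates many candidate splits at each node and keeps the one with the largest gain, which here is the first candidate.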

This explanatory capability of a classification (or regression) tree is one of its main benefits. Understanding data is a key factor in building better models.

Building a Minimal Model Using the Three Main Splits

So now that we know the main features predicting a network attack, thanks to our classification tree splits, let's use them to build a minimal classification tree with just the main three variables: count, dst_bytes, and flag.

We need to define the appropriate function to create labeled points.

def create_labeled_point_minimal(line_split):
    # keep only flag (3), dst_bytes (5) and count (22)
    clean_line_split = line_split[3:4] + line_split[5:6] + line_split[22:23]
    # convert flag to numeric categorical variable
    try:
        clean_line_split[0] = flags.index(clean_line_split[0])
    except ValueError:
        # flag not seen in the training data
        clean_line_split[0] = len(flags)
    # convert label to binary label
    attack = 1.0
    if line_split[41]=='normal.':
        attack = 0.0
    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

training_data_minimal = csv_data.map(create_labeled_point_minimal)
test_data_minimal = test_csv_data.map(create_labeled_point_minimal)

We then use these RDDs to train the model.

# Build the model
t0 = time()
tree_model_minimal = DecisionTree.trainClassifier(
    training_data_minimal, numClasses=2, 
    categoricalFeaturesInfo={0: len(flags)},
    impurity='gini', maxDepth=3, maxBins=32)
tt = time() - t0

print "Classifier trained in {} seconds".format(round(tt,3))
Classifier trained in 226.519 seconds

Now we can predict on the testing data and calculate accuracy.

predictions_minimal = tree_model_minimal.predict(test_data_minimal.map(lambda p: p.features))
labels_and_preds_minimal = test_data_minimal.map(lambda p: p.label).zip(predictions_minimal)

t0 = time()
test_accuracy = labels_and_preds_minimal.filter(lambda (v, p): v == p).count() / float(test_data_minimal.count())
tt = time() - t0

print "Prediction made in {} seconds. Test accuracy is {}".format(round(tt,3), round(test_accuracy,4))
Prediction made in 23.202 seconds. Test accuracy is 0.909

So we have trained a classification tree with just the three most important predictors in about half the time (226.5 versus 440.0 seconds), with only slightly lower test accuracy (0.909 versus 0.9155). In fact, a classification tree is a very good model selection tool!

Discover and read more posts from Jose A Dianes