1#
2# Licensed to the Apache Software Foundation (ASF) under one or more
3# contributor license agreements.  See the NOTICE file distributed with
4# this work for additional information regarding copyright ownership.
5# The ASF licenses this file to You under the Apache License, Version 2.0
6# (the "License"); you may not use this file except in compliance with
7# the License.  You may obtain a copy of the License at
8#
9#    http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""
19Decision Tree Classification Example.
20"""
21from __future__ import print_function
22
23from pyspark import SparkContext
24# $example on$
25from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
26from pyspark.mllib.util import MLUtils
27# $example off$
28
if __name__ == "__main__":
    # Train a decision tree classifier on the sample LIBSVM data, report its
    # error on a held-out split, print the learned tree, and round-trip the
    # model through save/load.

    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")

    # $example on$
    # Load and parse the data file into an RDD of LabeledPoint.
    data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
    # Split the data into training and test sets (30% held out for testing)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])

    # Train a DecisionTree model.
    #  Empty categoricalFeaturesInfo indicates all features are continuous.
    model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)

    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # Each element is a (label, prediction) pair; index it directly because
    # tuple-parameter lambdas like `lambda (v, p): ...` are invalid in Python 3.
    testErr = labelsAndPredictions.filter(
        lambda lp: lp[0] != lp[1]).count() / float(testData.count())
    print('Test Error = ' + str(testErr))
    print('Learned classification tree model:')
    print(model.toDebugString())

    # Save and load model
    model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
    sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
    # $example off$

    # Release the SparkContext's resources once the example is done.
    sc.stop()
56