程序员文章、书籍推荐和程序员创业信息与资源分享平台

网站首页 > 技术文章 正文

决策树算法原理与Python实现

hfteth 2025-08-05 18:21:09 技术文章 4 ℃

决策树作为机器学习中基础且应用广泛的分类算法,以其直观的树形结构和强大的解释性备受青睐。本文将从理论基础出发,结合Python代码实现,深入剖析决策树的核心原理与实战应用。

决策树基础概念

决策树是一种树形结构的分类模型,由内部节点和叶节点组成:

  • 内部节点表示特征判断
  • 叶节点表示分类结果

其工作原理类似"二十个问题"游戏,通过层层特征筛选缩小分类范围,最终得到实例的类别归属。在邮件分类场景中,决策树会先判断发件域名,再根据内容关键词进一步分类,展现出清晰的层级决策逻辑。

决策树核心数学原理

信息熵与信息增益




决策树构建算法

决策树通过递归方式构建,核心逻辑如下:

def createBranch():
    ''' 决策树递归构建逻辑 '''
    if 所有数据分类标签相同:
        return 类标签
    else:
        选择信息增益最大的特征
        划分数据集
        为每个划分子集递归创建分支
        return 分支节点

决策树核心代码实现

基础函数实现

计算香农熵

def calcShannonEnt(dataSet):
    """计算数据集的香农熵"""
    numEntries = len(dataSet)
    labelCounts = {}
    
    # 统计各类标签出现次数
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt

数据集划分函数

def splitDataSet(dataSet, index, value):
    """根据特征和值划分数据集"""
    retDataSet = []
    for featVec in dataSet:
        if featVec[index] == value:
            # 提取划分后的数据(排除当前特征列)
            reducedFeatVec = featVec[:index]
            reducedFeatVec.extend(featVec[index+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

最优特征选择

def chooseBestFeatureToSplit(dataSet):
    """选择信息增益最大的特征"""
    numFeatures = len(dataSet[0]) - 1  # 特征数量
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain, bestFeature = 0.0, -1
    
    # 遍历所有特征计算信息增益
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        
        # 计算当前特征划分后的熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
            
        # 计算信息增益并更新最优特征
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

决策树构建与分类

递归创建决策树

def createTree(dataSet, labels):
    """递归构建决策树"""
    classList = [example[-1] for example in dataSet]
    
    # 停止条件1:所有标签相同
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    
    # 停止条件2:所有特征使用完毕
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    
    # 选择最优特征并构建树
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    
    # 递归处理每个划分子集
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

使用决策树进行分类

def classify(inputTree, featLabels, testVec):
    """使用决策树对测试数据分类"""
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    
    # 递归遍历决策树
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel

决策树项目实战

案例1:鱼类与非鱼类分类

数据准备

def createDataSet():
    """创建鱼类分类数据集"""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

构建与测试流程

  1. 收集数据:使用createDataSet()生成样本
  2. 准备数据:由于数据已离散化,无需额外处理
  3. 训练模型:调用createTree构建决策树
  4. 测试模型:使用classify对新数据分类

案例2:隐形眼镜类型预测

数据解析

# 解析文本文件数据
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

决策树存储与加载

def storeTree(inputTree, filename):
    """存储决策树到文件"""
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    """从文件加载决策树"""
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

决策树算法特点

优点

  • 计算复杂度低,适合大规模数据
  • 分类结果直观易懂,具有天然可解释性
  • 支持处理缺失值和不相关特征

缺点

  • 容易产生过拟合,需通过剪枝优化
  • 对高度非线性数据分类效果有限

适用场景

  • 标称型数据(分类数据)和离散化的数值型数据
  • 需要解释分类逻辑的场景
  • 数据预处理要求较低的快速建模场景

通过以上代码实现和项目案例,我们可以清晰理解决策树从理论到实践的完整流程。作为基础分类算法,决策树不仅是机器学习入门的重要内容,也是理解集成学习算法(如随机森林)的基础,在数据挖掘领域具有不可替代的地位。

最近发表
标签列表