def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one tree node as a boxed label with an arrow to its parent.

    nodeTxt  -- text shown inside the node box.
    centerPt -- (x, y) position of the label text, in axes-fraction coords.
    parentPt -- (x, y) position the arrow points at, in axes-fraction coords.
    nodeType -- value passed to ``annotate`` as ``bbox`` (the box style).

    Draws on ``createPlot.ax1``, which ``createPlot`` sets before plotting.
    ``arrow_args`` is a module-level name defined outside this block.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',        # point being annotated
        xytext=centerPt, textcoords='axes fraction',  # where the text sits
        va="center", ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )


def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a child node and its parent.

    cntrPt    -- (x, y) position of the child node.
    parentPt  -- (x, y) position of the parent node.
    txtString -- label to place on the connecting edge.

    The text is drawn on ``createPlot.ax1``, centred and rotated 30 degrees.
    """
    child_x, child_y = cntrPt[0], cntrPt[1]
    parent_x, parent_y = parentPt[0], parentPt[1]
    # Midpoint of the segment joining the child and the parent.
    mid_x = (parent_x - child_x) / 2.0 + child_x
    mid_y = (parent_y - child_y) / 2.0 + child_y
    createPlot.ax1.text(mid_x, mid_y, txtString,
                        va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the (sub)tree ``myTree`` beneath ``parentPt``.

    myTree   -- nested dict ``{label: {branch_value: subtree_or_leaf}}``; the
                first key is used as this node's label, and any branch value
                that is itself a dict is drawn as a subtree.
    parentPt -- (x, y) position of the parent node, in axes-fraction coords.
    nodeTxt  -- text written on the edge from the parent to this node.

    Drawing state lives in function attributes that ``createPlot`` sets
    before the first call: ``plotTree.totalW`` / ``plotTree.totalD`` (the
    horizontal and vertical divisors) and ``plotTree.xOff`` /
    ``plotTree.yOff`` (the current drawing cursor).  ``xOff`` advances by
    ``1/totalW`` for each leaf drawn; ``yOff`` drops by ``1/totalD`` while the
    children are drawn and is restored before returning.
    """
    # Presumably the number of leaves under this node (helper defined
    # elsewhere); it determines how wide this subtree is.
    numLeafs = getNumLeafs(myTree)
    # The first key is the label of this decision node.
    firstStr = next(iter(myTree))
    # Centre this node over the horizontal span its leaves will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge from the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        child = secondDict[key]
        if isinstance(child, dict):
            # A dict is a nested subtree: recurse with this node as parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Anything else is a leaf: move to the next leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the tree ``inTree`` in figure 1 and display it.

    inTree -- nested-dict tree, passed unchanged to ``plotTree``.

    Sets ``createPlot.ax1`` (the axes that ``plotNode`` and ``plotMidText``
    draw on) and initialises the ``plotTree`` function attributes
    ``totalW``, ``totalD``, ``xOff`` and ``yOff`` before calling ``plotTree``.
    ``plt`` is presumably ``matplotlib.pyplot``, imported elsewhere in the
    file.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Frameless axes with no tick marks.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    tree_width = float(getNumLeafs(inTree))   # stored as the tree's width
    plotTree.totalW = tree_width
    plotTree.totalD = float(getTreeDepth(inTree))  # stored as the tree's depth

    # Start half a slot left of the axes; plotTree advances xOff by a full
    # slot (1/totalW) before drawing each leaf, so the first leaf lands at
    # 0.5/totalW.
    plotTree.xOff = -0.5 / tree_width
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()