"""Train a depth-limited decision tree on the Iris dataset, print its test
accuracy, and save a rendering of the fitted tree under ./output/."""
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Where the tree plot is written. The file name keeps its existing spelling
# ("dicision") so anything that already reads this file keeps finding it.
OUTPUT_DIR = Path("./output")
OUTPUT_FILE = OUTPUT_DIR / "dicision_tree_c.png"

# Load the dataset: feature matrix X and integer class labels y.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the samples for testing; the fixed seed makes the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Gini-impurity decision tree with depth capped at 3 (keeps the plotted tree
# compact); fixed seed so tie-breaking between features is reproducible.
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)

# Train the model.
clf.fit(X_train, y_train)

# Predict once and score those predictions, instead of calling clf.score(),
# which would run prediction on the test set a second time.
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# Visualize the fitted tree on an explicit Axes so the figure can be closed
# deterministically after saving.
fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(
    clf,
    filled=True,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    ax=ax,
)
# savefig does not create missing directories and would raise
# FileNotFoundError, so make sure the output directory exists first.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(OUTPUT_FILE)
plt.close(fig)