"""Train and visualize a decision-tree classifier on the Iris dataset.

Steps: load Iris, hold out 20% of the samples as a test set, fit a
depth-limited Gini decision tree, print its test accuracy, and save a
rendering of the fitted tree to ``./output/dicision_tree_c.png``.
"""

from pathlib import Path

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Fixed seed shared by the train/test split and the tree so runs are reproducible.
RANDOM_STATE = 42
# NOTE: the "dicision" spelling (sic) is kept so anything that reads this
# file by name keeps working.
OUTPUT_PATH = Path("./output/dicision_tree_c.png")

# Load the dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split the dataset: 80% train / 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE
)

# Create the decision-tree model (Gini impurity, depth capped at 3)
clf = DecisionTreeClassifier(
    criterion="gini", max_depth=3, random_state=RANDOM_STATE
)

# Train the model
clf.fit(X_train, y_train)

# Predict on the held-out test set
y_pred = clf.predict(X_test)

# Compute accuracy from the predictions above; this equals clf.score(X_test,
# y_test) but avoids running predict a second time.
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# Visualize the decision tree on an explicitly created figure/axes
fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(
    clf,
    filled=True,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    ax=ax,
)

# savefig does not create missing directories, so ensure ./output exists first.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(OUTPUT_PATH)
# Release the figure now that it has been written to disk.
plt.close(fig)