41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# Import the required libraries
|
||
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
|
||
|
||
# 1. Load the iris dataset and pull out the feature matrix and the labels.
iris = load_iris()
X, y = iris.data, iris.target

# 2. Split into training and test sets (80% train, 20% test); the fixed
#    random_state keeps the split the same from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# 3-4. Create a random forest of 100 trees and train it on the training set.
#      fit() returns the fitted estimator, so construction and training chain.
rf_classifier = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
).fit(X_train, y_train)
|
||
|
||
# 5. Predict class labels for the held-out test samples.
y_pred = rf_classifier.predict(X_test)

# 6. Evaluate the model: overall accuracy first, then the detailed
#    classification report (precision, recall, F1 score, ...).
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
report = classification_report(y_true=y_test, y_pred=y_pred)

print(f"模型准确率:{accuracy:.4f}")
print("分类报告:")
print(report)
|
||
|
||
# 7. Visualize the feature importances (optional).
feature_importances = rf_classifier.feature_importances_

# savefig() does not create missing parent directories, so make sure the
# output folder exists first; otherwise the script would fail with
# FileNotFoundError only after all the training work has been done.
output_path = Path('./output/random_forest_c.png')
output_path.parent.mkdir(parents=True, exist_ok=True)

# Horizontal bar chart: one bar per iris feature name.
plt.figure(figsize=(10, 6))
plt.barh(iris.feature_names, feature_importances)
plt.xlabel("feature importance")
plt.title("random forest model feature importance")
plt.savefig(output_path)
|