mechine_learning/random_forest_classification.py

41 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入所需库
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
# 1. Load the dataset (the Iris dataset is used here).
iris = load_iris()
X = iris.data
y = iris.target

# 2. Split into training and test sets (80% training, 20% test).
#    random_state is fixed so the split is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3. Create the random forest classifier (100 trees, fixed seed).
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)

# 4. Train the model.
rf_classifier.fit(X_train, y_train)

# 5. Predict on the test set.
y_pred = rf_classifier.predict(X_test)

# 6. Evaluate model performance.
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率:{accuracy:.4f}")
# Print the detailed classification report (precision, recall, F1 score, ...).
print("分类报告:")
print(classification_report(y_test, y_pred))

# 7. Visualize feature importances (optional).
feature_importances = rf_classifier.feature_importances_
output_path = Path('./output/random_forest_c.png')
# savefig does not create missing parent directories; without this the script
# would fail with FileNotFoundError after all the training work is done.
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
plt.barh(iris.feature_names, feature_importances)
plt.xlabel("feature importance")
plt.title("random forest model feature importance")
plt.savefig(output_path)