代码拉取完成,页面将自动刷新
同步操作将从 脱线/faiss_dog_cat_question 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier, StackingClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.metrics import accuracy_score
from sklearn.base import clone
from FaissKNeighbors import FaissKNeighbors
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ParameterGrid
from util import createXY # 确保这个函数在同一个目录下或者正确安装了对应的包
import logging
from joblib import dump
import time
# Configure root logging: INFO level, "timestamp - level - message" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Data loading
def load_data():
    """Load training features and labels.

    Delegates to ``util.createXY`` to extract features from the images
    under ``data/train`` using the 'flat' method, writing intermediate
    artifacts into the current directory.

    Returns:
        Tuple ``(X, y)`` of feature matrix and label vector.
    """
    source_folder = "data/train"  # training image folder
    output_folder = "."           # where features/labels are cached
    features, labels = createXY(source_folder, output_folder, method='flat')
    return features, labels
# Grid-search a single model and return the best *fitted* model
def train_single_model(model, X_train, y_train, X_test, y_test, param_grid):
    """Exhaustively evaluate ``param_grid`` on ``model`` and return the winner.

    Args:
        model: an unfitted scikit-learn estimator (mutated in place by
            ``set_params`` during the search).
        X_train, y_train: training split.
        X_test, y_test: held-out split used to score each combination.
        param_grid: dict accepted by ``sklearn.model_selection.ParameterGrid``.

    Returns:
        Tuple ``(best_model, best_score)`` where ``best_model`` is a fitted
        clone configured with the winning parameters, or ``None`` if no
        combination scored above 0 (e.g. an empty grid).
    """
    best_score = 0
    best_params = None
    for params in ParameterGrid(param_grid):
        model.set_params(**params)
        model.fit(X_train, y_train)
        score = accuracy_score(y_test, model.predict(X_test))
        if score > best_score:
            best_score = score
            best_params = params
    if best_params is None:
        return None, best_score
    # BUG FIX: the original returned clone(model) without fitting it, so the
    # "best model" handed back to callers (and potentially saved to disk) was
    # unfitted and could not predict. Refit the clone with the winning params.
    best_model = clone(model).set_params(**best_params)
    best_model.fit(X_train, y_train)
    return best_model, best_score
# Ensemble training pipeline
def train_ensemble_models(X, y):
    """Train base and ensemble classifiers, log scores, save the best model.

    Grid-searches KNN / RandomForest / GradientBoosting base learners, then
    fits hard/soft voting, stacking, bagging, pasting, AdaBoost and a fixed
    GradientBoosting ensemble on a 75/25 split. The highest-scoring model is
    persisted to ``best_model.joblib`` via joblib.

    Args:
        X: feature matrix.
        y: label vector.

    Returns:
        Tuple ``(model, score)`` for the best-performing model.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)

    # Hyper-parameter grids for the three base learners.
    param_grid_knn = {'n_neighbors': [3, 5, 7, 9]}
    param_grid_rf = {'n_estimators': [10, 50, 100], 'max_depth': [None, 10, 20, 30]}
    param_grid_gb = {'n_estimators': [10, 50, 100], 'learning_rate': [0.01, 0.1, 0.2]}

    knn, knn_score = train_single_model(KNeighborsClassifier(), X_train, y_train, X_test, y_test, param_grid_knn)
    logging.info(f"KNN Best Score: {knn_score}")
    rf, rf_score = train_single_model(RandomForestClassifier(), X_train, y_train, X_test, y_test, param_grid_rf)
    logging.info(f"Random Forest Best Score: {rf_score}")
    gb, gb_score = train_single_model(GradientBoostingClassifier(), X_train, y_train, X_test, y_test, param_grid_gb)
    logging.info(f"Gradient Boosting Best Score: {gb_score}")

    def _fit_score_log(name, clf):
        # Fit `clf` on the training split, log its test accuracy and
        # wall-clock training time, and return (clf, score).
        start_time = time.time()
        clf.fit(X_train, y_train)
        score = accuracy_score(y_test, clf.predict(X_test))
        logging.info(f"{name} Best Score: {score}")
        logging.info(f"{name} Training Time: {time.time() - start_time} seconds")
        return clf, score

    # Ensembles built from the tuned base learners (each ensemble refits
    # clones internally, so sharing this list is safe).
    base_estimators = [('knn', knn), ('rf', rf), ('gb', gb)]
    voting_clf_hard, voting_score_hard = _fit_score_log(
        "Hard Voting Classifier",
        VotingClassifier(estimators=base_estimators, voting='hard'))
    voting_clf_soft, voting_score_soft = _fit_score_log(
        "Soft Voting Classifier",
        VotingClassifier(estimators=base_estimators, voting='soft'))
    stacking_clf, stacking_score = _fit_score_log(
        "Stacking Classifier",
        StackingClassifier(estimators=base_estimators, final_estimator=LogisticRegression()))

    # NOTE(review): `base_estimator` was renamed to `estimator` in
    # scikit-learn 1.2 and removed in 1.4 — confirm the pinned version.
    bagging_clf, bagging_score = _fit_score_log(
        "Bagging Classifier",
        BaggingClassifier(base_estimator=KNeighborsClassifier(), n_estimators=10, random_state=2023))
    # "Pasting" = bagging with sampling *without* replacement.
    pasting_clf, pasting_score = _fit_score_log(
        "Pasting Classifier",
        BaggingClassifier(base_estimator=KNeighborsClassifier(), n_estimators=10, bootstrap=False, random_state=2023))
    adaboost_clf, adaboost_score = _fit_score_log(
        "AdaBoost Classifier",
        AdaBoostClassifier(n_estimators=50, random_state=2023))
    gradient_boost_clf, gradient_boost_score = _fit_score_log(
        "Gradient Boosting Classifier",
        GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, random_state=2023))

    # Select and persist the overall best (model, score) pair.
    best_model = max(
        (knn, knn_score),
        (rf, rf_score),
        (gb, gb_score),
        (voting_clf_hard, voting_score_hard),
        (voting_clf_soft, voting_score_soft),
        (stacking_clf, stacking_score),
        (bagging_clf, bagging_score),
        (pasting_clf, pasting_score),
        (adaboost_clf, adaboost_score),
        (gradient_boost_clf, gradient_boost_score),
        key=lambda pair: pair[1],
    )
    logging.info(f"Best Model: {best_model[0]}")
    logging.info(f"Best Score: {best_model[1]}")
    model_filename = "best_model.joblib"
    dump(best_model[0], model_filename)  # joblib serialization
    return best_model
# Script entry point
def main():
    """Load the data set and run the full ensemble training pipeline."""
    features, labels = load_data()
    train_ensemble_models(features, labels)


if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。