python

しょっちゅう忘れることを書いておく。

33

188 views

GridSearch

# coding: UTF-8
import os
import numpy as np
from catboost import CatBoost, CatBoostRegressor, CatBoostClassifier
from catboost import Pool
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score
from sklearn.metrics import mean_squared_error
from lightgbm import LGBMRegressor, LGBMClassifier
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, roc_curve
from util import read_csv


def custom_auc(ground_truth, predictions):
    fpr, tpr, _ = roc_curve(ground_truth, predictions[:, 1], pos_label='Assessed')
    return auc(fpr, tpr)



def create_train_data():
    records = read_csv(r"data/train.csv")
    dataset = []
    labels = []

    for record in records:
        dataset.append(record[0:-1])
        labels.append(record[-1])

    # 正規化する
    scaler  = MinMaxScaler(feature_range=(0, 1), copy=True)

    # 学習データの正規化
    scaler.fit(dataset)
    dataset = scaler.transform(dataset)


    return dataset, labels


def load_test_data():
    test_data = read_csv(r"data/test.csv")
    # 正規化する
    scaler  = MinMaxScaler(feature_range=(0, 1), copy=True)

    # 学習データの正規化
    scaler.fit(test_data)
    test_data = scaler.transform(test_data)


    return test_data



def write_result_csv(file_name, y_pred):
    if not os.path.exists("results"):
        os.makedirs("results")

    with open(os.path.join("results", file_name), "w") as f:
        for i, y in enumerate(y_pred):
            if y < 0.:
                y = 0.
            if y > 1.:
                y = 1.
            f.write(f"{i},{y:.6f}\n")



def grid_random_forest_classifier(train_dataset, train_labels, test_data):
    """
    ランダムフォレストによる2値分類
    """

    model = RandomForestClassifier()
    my_auc = make_scorer(custom_auc, greater_is_better=True, needs_proba=True)
    model_reg = GridSearchCV(model, {'max_depth': [8,10,12,16,32], 'n_estimators': [100,200,300,500]}, verbose=2, refit='AUC', scoring='roc_auc')

    model_reg.fit(train_dataset, train_labels)
    print("{},{}".format(model_reg.best_params_, model_reg.best_score_))
    print('Train score: {:.3f}'.format(model_reg.score(train_dataset, train_labels)))
    print("学習が終了しました")

    # 予測データの作成
    test_prob = model_reg.predict_proba(test_data)[:, 1]
    write_result_csv("grid_random_forest_classifier_result.csv", test_prob)




if __name__ == '__main__':

    train_dataset, train_labels = create_train_data()
    test_data = load_test_data()
    grid_random_forest_classifier(train_dataset, train_labels, test_data)


スタッキング

スタッキングは今まで自力実装していたが、sklearnにライブラリがあるらしい。
https://qiita.com/maskot1977/items/de7383898123fa378d86がわかりやすい。

Page 45 of 56.

前のページ 次のページ



[添付ファイル]


お問い合わせ

プロフィール

マッスル

自己紹介

本サイトの作成者。
趣味:プログラム/水耕栽培/仮想通貨/激辛好き
プログラムは趣味と勉強を兼ねて、のんびり本サイトを作っています。
フレームワークはdjango。
仮想通貨はNEMが好き。
水耕栽培は激辛好きが高じて、キャロライナ・リーパーの栽培にチャレンジ中。

サイト/ブログ

https://www.osumoi-stdio.com/pyarticle/

ツイッター

@darkimpact0626