python, R, vimでデータマイニング

python, R, vim で疑問に思ったことなどを

cross_val_predictで足りない機能を追加

1. 足りない機能

cross_val_predictを良く使用します。
ただ足りないと思う機能が一つあります。
各FOLD毎に構築したestimatorを取得したいのですが
そのような機能がなさそうです。
仕方がないので実作の関数を作成しました。

2. cross_val_classifier

from utils4ml.sklearnwrappers import cross_val_classifier

3. 使用例

predictedとestimatorのリストを取得できます。

Code:
# %%
import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier
from utils4ml.utils import load_bank_classifier
from utils4ml.sklearnwrappers import cross_val_classifier

X, y = load_bank_classifier()
y = y.cat.codes
predicted, estimators = cross_val_classifier(
    LGBMClassifier(),
    X,
    y,
)
# %%

4. estimatorsの適用

別データにestimatorsを適用します。

Code:
# %%
predicted_others = [
    e.predict_proba(X)
    for e in estimators
]
predicted_other = np.mean(
    predicted_others,
    axis=0,
)
predicted_other = pd.DataFrame(predicted_other)
predicted_other.head().to_csv(
    'output/predicted_other.csv'
)
# %%
Table 1. Result:
  0 1

0

0.9872150068879645

0.012784993112035726

1

0.9833755317309659

0.016624468269033875

2

0.9958059169918817

0.004194083008118234

3

0.9965370372207216

0.0034629627792783267

4

0.9981163874399529

0.0018836125600472052