Another parameter, another set of quirks! `min_samples_leaf` is sort of similar to `max_depth`: it helps us avoid overfitting, and it's also non-obvious what you should use as your upper and lower limits to search between. Let's do what we did last week - build a forest with no parameters, see what it does, and use that to set our upper and lower limits!
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
X, y = data.data, data.target

rfArgs = {"random_state": 0,
          "n_jobs": -1,
          "class_weight": "balanced",
          "n_estimators": 18,
          "oob_score": True}

clf = RandomForestClassifier(**rfArgs)
clf.fit(X, y)
Let's use the handy function from here to crawl the number of samples in a tree's leaf nodes:
from sklearn.tree import _tree

def leaf_samples(tree, node_id=0):
    # Recursively walk the tree, collecting n_node_samples at every leaf
    left_child = tree.children_left[node_id]
    right_child = tree.children_right[node_id]
    if left_child == _tree.TREE_LEAF:
        samples = np.array([tree.n_node_samples[node_id]])
    else:
        left_samples = leaf_samples(tree, left_child)
        right_samples = leaf_samples(tree, right_child)
        samples = np.append(left_samples, right_samples)
    return samples
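(Quick aside of mine, not from the original walkthrough: you can point the crawler at any single fitted tree to eyeball the raw leaf sizes.)

# leaf sizes for the first tree in the baseline forest we just fit
leaf_samples(clf.estimators_[0].tree_)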
Last week we made a function to grab them for a whole forest - since this is the second time we're doing this, and we may do it again, let's make a modular little function that takes a crawler function as an argument!
def getForestParams(X, y, param, kwargs):
    clf = RandomForestClassifier(**kwargs)
    clf.fit(X, y)
    params = np.hstack([param(estimator.tree_)
                        for estimator in clf.estimators_])
    return {"min": params.min(),
            "max": params.max()}
Let's see it in action!
getForestParams(X, y, leaf_samples, rfArgs)
#> {'max': 199, 'min': 1}
Almost ready to start optimizing! Since part of what we get out of optimizing `min_samples_leaf` is regularization (and because it's just good practice!), let's make a metric with some cross-validation. Luckily, Scikit has a built-in `cross_val_score` function. We'll just need to do a teensy bit of tweaking to make it use the area under a `precision_recall_curve`.
from sklearn.model_selection import cross_val_score

def auc_prc(estimator, X, y):
    # Refit on the data we're handed and score the out-of-bag votes for the positive class
    estimator.fit(X, y)
    y_pred = estimator.oob_decision_function_[:, 1]
    precision, recall, _ = precision_recall_curve(y, y_pred)
    return auc(recall, precision)

def getForestAccuracyCV(X, y, kwargs):
    # Average the AUC-PR scorer over five cross-validation folds
    clf = RandomForestClassifier(**kwargs)
    return np.mean(cross_val_score(clf, X, y, scoring=auc_prc, cv=5))
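(Another quick aside that isn't part of the original runs: you can sanity-check the metric directly on the baseline arguments.)

# mean cross-validated AUC-PR for the un-tuned forest
getForestAccuracyCV(X, y, rfArgs)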
Awesome, now we have a metric that can be fed into our binary search.
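I'm not reproducing `bgs` here (it's the little search module from last week's post), but purely as an illustrative sketch - the real `compareValsBaseCase` has a different signature and extra bookkeeping for the base case and timings - a midpoint search like it could look roughly like this: score both endpoints, then keep bisecting toward whichever endpoint currently scores better.

def midpointSearch(X, y, metric, kwargs, param, lower, upper, results=None):
    # hypothetical stand-in for bgs.compareValsBaseCase, for illustration only
    if results is None:
        results = {}
    for val in (lower, upper):
        if val not in results:
            results[val] = metric(X, y, {**kwargs, param: val})
    mid = (lower + upper) // 2
    if mid in (lower, upper):  # the interval can't shrink any further
        return results
    results[mid] = metric(X, y, {**kwargs, param: mid})
    # recurse into the half bounded by the better-scoring endpoint
    if results[lower] >= results[upper]:
        return midpointSearch(X, y, metric, kwargs, param, lower, mid, results)
    return midpointSearch(X, y, metric, kwargs, param, mid, upper, results)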
min_samples_leaf = bgs.compareValsBaseCase(X,
                                           y,
                                           getForestAccuracyCV,
                                           rfArgs,
                                           "min_samples_leaf",
                                           0,
                                           1,
                                           199)
bgs.showTimeScoreChartAndGraph(min_samples_leaf)
min_samples_leaf | score | time |
---|---|---|
1 | 0.981662 | 1.402102 |
199 | 0.506455 | 1.416349 |
100 | 0.506455 | 1.401090 |
51 | 0.506455 | 1.394548 |
26 | 0.975894 | 1.396503 |
14 | 0.982954 | 1.398522 |
7 | 0.979888 | 1.398929 |
10 | 0.984789 | 1.404815 |
12 | 0.986302 | 1.391171 |
min_samples_leaf | score | time | scoreTimeRatio |
---|---|---|---|
1 | 0.992414 | 0.473848 | 0.082938 |
199 | 0.002084 | 1.039718 | 0.000000 |
100 | 0.002084 | 0.433676 | 0.000111 |
51 | 0.002084 | 0.173824 | 0.000396 |
26 | 0.980393 | 0.251484 | 0.154448 |
14 | 0.995105 | 0.331692 | 0.118839 |
7 | 0.988716 | 0.347858 | 0.112585 |
10 | 0.998930 | 0.581632 | 0.067998 |
12 | 1.002084 | 0.039718 | 1.000000 |
Looks like the action's between 1 and 51. More than that, and the score drops while the runtime climbs - the opposite of what we want!
min_samples_leaf = bgs.compareValsBaseCase(X,
                                           y,
                                           getForestAccuracyCV,
                                           rfArgs,
                                           "min_samples_leaf",
                                           0,
                                           1,
                                           51)
bgs.showTimeScoreChartAndGraph(min_samples_leaf)
min_samples_leaf | score | time |
---|---|---|
1 | 0.981662 | 1.389387 |
51 | 0.506455 | 1.403807 |
26 | 0.975894 | 1.404517 |
14 | 0.982954 | 1.385420 |
7 | 0.979888 | 1.398840 |
10 | 0.984789 | 1.393863 |
12 | 0.986302 | 1.411774 |
min_samples_leaf | score | time | scoreTimeRatio |
---|---|---|---|
1 | 0.992414 | 0.188492 | 0.200671 |
51 | 0.002084 | 0.735618 | 0.000000 |
26 | 0.980393 | 0.762561 | 0.048920 |
14 | 0.995105 | 0.037944 | 1.000000 |
7 | 0.988716 | 0.547179 | 0.068798 |
10 | 0.998930 | 0.358303 | 0.106209 |
12 | 1.002084 | 1.037944 | 0.036709 |
Big drop-off after 26, it seems!
min_samples_leaf = bgs.compareValsBaseCase(X,
                                           y,
                                           getForestAccuracyCV,
                                           rfArgs,
                                           "min_samples_leaf",
                                           0,
                                           1,
                                           26)
bgs.showTimeScoreChartAndGraph(min_samples_leaf)
min_samples_leaf | score | time |
---|---|---|
1 | 0.981662 | 1.407957 |
26 | 0.975894 | 1.398042 |
14 | 0.982954 | 1.396782 |
7 | 0.979888 | 1.396096 |
10 | 0.984789 | 1.402322 |
12 | 0.986302 | 1.401080 |
min_samples_leaf | score | time | scoreTimeRatio |
---|---|---|---|
1 | 0.650270 | 1.084306 | 0.040144 |
26 | 0.096077 | 0.248406 | 0.000000 |
14 | 0.774346 | 0.142157 | 0.954016 |
7 | 0.479788 | 0.084306 | 1.000000 |
10 | 0.950677 | 0.609184 | 0.221294 |
12 | 1.096077 | 0.504512 | 0.336668 |
One more with 14 as our upper limit!
min_samples_leaf = bgs.compareValsBaseCase(X,
                                           y,
                                           getForestAccuracyCV,
                                           rfArgs,
                                           "min_samples_leaf",
                                           0,
                                           1,
                                           14)
bgs.showTimeScoreChartAndGraph(min_samples_leaf)
min_samples_leaf | score | time |
---|---|---|
1 | 0.981662 | 1.401341 |
14 | 0.982954 | 1.400361 |
7 | 0.979888 | 1.402408 |
4 | 0.981121 | 1.401396 |
3 | 0.983580 | 1.401332 |
I suppose when it gets this small we could use a regular Grid Search, but... maybe next week! Or maybe another variable! Or maybe benchmarks vs `GridSearchCV` and/or `RandomizedSearchCV`. Who knows what the future holds?
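For what it's worth, here's a rough sketch (my own illustration, not one of this post's actual runs) of what that grid search over the narrowed range might look like, reusing the `rfArgs` and `auc_prc` defined above:

from sklearn.model_selection import GridSearchCV

# exhaustively sweep the narrowed range, scored with the same AUC-PR scorer
grid = GridSearchCV(RandomForestClassifier(**rfArgs),
                    param_grid={"min_samples_leaf": list(range(1, 15))},
                    scoring=auc_prc,
                    cv=5)
grid.fit(X, y)
grid.best_params_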