[search] Added confidence interval calculations via bootstrap.

This commit is contained in:
Yuri Gorshenin 2017-09-05 18:47:43 +03:00 committed by mpimenov
parent b1b1edbbc2
commit 26f3fcfb6e

View file

@ -4,6 +4,7 @@ from math import exp, log
from scipy.stats import pearsonr
from sklearn import svm
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.utils import resample
import argparse
import collections
import itertools
@ -20,6 +21,8 @@ NAME_SCORES = ['Zero', 'Substring', 'Prefix', 'Full Match']
# Names of the search result types; each one is also a feature name (see FEATURES).
SEARCH_TYPES = ['POI', 'Building', 'Street', 'Unclassified', 'Village', 'City', 'State', 'Country']
# Full list of feature names: four scalar features followed by the name-score
# and search-type names. Positions matter: code looks entries up with
# FEATURES.index(...) to address the matching coefficient.
FEATURES = ['DistanceToPivot', 'Rank', 'FalseCats', 'ErrorsMade'] + NAME_SCORES + SEARCH_TYPES
# Number of resample-and-refit rounds performed by show_bootstrap_statistics().
BOOTSTRAP_ITERATIONS = 10000
def transform_name_score(value, categories_match):
if categories_match == 1:
@ -220,6 +223,36 @@ def cpp_output(features, ws):
print_array('kType', 'Model::TYPE_COUNT', st)
def show_bootstrap_statistics(clf, X, y, features):
    """Print bootstrap 95% confidence intervals for the normalized coefficients.

    The classifier is refitted BOOTSTRAP_ITERATIONS times, each time on a
    sample drawn from (X, y) by ``resample``. The normalized coefficients of
    every fit are collected, and for each feature the 2.5th and 97.5th
    percentiles of its coefficient are printed.

    Args:
        clf: Estimator providing ``fit(X, y)`` and, once fitted, the
            ``coef_`` attribute read by ``get_normalized_coefs``. It is
            refitted in place, so after the call it holds the fit for the
            last bootstrap sample.
        X: Feature matrix handed to ``resample`` and ``clf.fit``.
        y: Target values handed to ``resample`` and ``clf.fit``.
        features: Feature names, in the same order as the coefficients.
    """
    samples = []
    for _ in range(BOOTSTRAP_ITERATIONS):
        X_sample, y_sample = resample(X, y)
        clf.fit(X_sample, y_sample)
        samples.append(get_normalized_coefs(clf))

    # One row per bootstrap round, one column per feature, so percentiles
    # taken along axis 0 give one bound per feature.
    coefs = np.array(samples)
    lows = np.percentile(coefs, 2.5, axis=0)
    highs = np.percentile(coefs, 97.5, axis=0)

    print()
    print('***** Bootstrap 95% confidence intervals *****')
    # Label rows with the names the caller passed in, not a module global,
    # so the output always matches the coefficients being analysed.
    for name, lo, hi in zip(features, lows, highs):
        print('{}: ({:.3f}, {:.3f})'.format(name, lo, hi))
def get_normalized_coefs(clf):
    """Return the first row of ``clf.coef_`` scaled by its largest magnitude.

    After scaling, the coefficient with the largest absolute value becomes
    +1 or -1 and all others fall within [-1, 1].

    Args:
        clf: Fitted estimator exposing a 2-D ``coef_`` attribute.

    Returns:
        A new numpy array with the scaled coefficients. If every coefficient
        is zero there is nothing to scale by, so an array of zeros is
        returned instead of dividing by zero.
    """
    ws = np.asarray(clf.coef_[0])
    max_w = np.abs(ws).max()
    if max_w == 0:
        # Dividing would produce NaNs (0 / 0); the normalized form of an
        # all-zero vector is simply all zeros.
        return np.zeros_like(ws, dtype=float)
    return np.divide(ws, max_w)
def main(args):
data = pd.read_csv(sys.stdin)
normalize_data(data)
@ -238,9 +271,9 @@ def main(args):
gs = GridSearchCV(clf, grid, scoring='roc_auc', cv=cv)
gs.fit(xs, ys)
ws = gs.best_estimator_.coef_[0]
max_w = max(abs(w) for w in ws)
ws = np.divide(ws, max_w)
print('Best params: {}'.format(gs.best_params_))
ws = get_normalized_coefs(gs.best_estimator_)
# Following code restores coeffs for merged features.
ws[FEATURES.index('Building')] = ws[FEATURES.index('POI')]
@ -261,11 +294,15 @@ def main(args):
else:
raw_output(FEATURES, ws)
if args.bootstrap:
show_bootstrap_statistics(clf, xs, ys, FEATURES)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and pass the
    # resulting namespace to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, help='random seed')
    # The remaining options are plain on/off switches; register them in the
    # order they should appear in --help.
    for flag, description in (
            ('--pearson', 'show pearson statistics'),
            ('--cpp', 'generate output in the C++ format'),
            ('--bootstrap', 'show bootstrap confidence intervals'),
    ):
        parser.add_argument(flag, action='store_true', help=description)
    args = parser.parse_args()
    main(args)