Skip to content

Commit a5c886e

Browse files
Merge pull request #192 from Thilakraj1998/main
Minor Changes
2 parents 12fc29f + 1b4e416 commit a5c886e

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

blobcity/config/classifier_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class classifier_config:
136136
'reg_lambda': {'float':[1e-3,0.1]},
137137
'booster':{'str':['gbtree', 'gblinear','dart']},
138138
'verbosity':{'str':[0]},
139+
'n_jobs':{'str':[1]}
139140
}
140141
],
141142
"RadiusNeighborsClassifier":[

blobcity/config/regressor_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ class regressor_config:
221221
'reg_alpha': {'float':[1e-3,0.1]},
222222
'reg_lambda': {'float':[1e-3,0.1]},
223223
'verbosity':{'str':[0]},
224+
'n_jobs':{'str':[1]}
224225
}
225226
],
226227
"GammaRegressor":[

blobcity/main/modelSelection.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,23 @@ def sort_score(modelScore):
8787
sorted_dict=dict(sorted(modelScore.items(), key=lambda item: item[1],reverse=True))
8888
return sorted_dict
8989

90+
def eval_model(models,m,X,Y,k):
91+
"""
92+
param1: dictionary
93+
param2: string
94+
param3: pd.DataFrame
95+
param4: pd.Dataframe/pd.Series/numpy.array
96+
param5: int
97+
return: float
98+
99+
Function to fetch cross validation score for specific models from the dictionary
100+
"""
101+
if m in ['XGBClassifier','XGBRegressor']: model=models[m][0](verbosity=0,n_jobs=1)
102+
elif m in ['CatBoostRegressor','CatBoostClassifier']: model=models[m][0](verbose=False)
103+
elif m in ['LGBMClassifier','LGBMRegressor']: model=models[m][0](verbose=-1,n_jobs=1)
104+
else: model=models[m][0]()
105+
return cv_score(model,X,Y,k)
106+
90107
def train_on_sample_data(dataframe,target,models):
91108
"""
92109
param1: pandas.DataFrame
@@ -107,10 +124,7 @@ def train_on_sample_data(dataframe,target,models):
107124
modelScore={}
108125
prog.create_progressbar(len(models),"Quick Search (Stage 1 of 3) :")
109126
for m in models:
110-
if m in ['XGBClassifier','XGBRegressor']: model=models[m][0](verbosity=0)
111-
elif m in ['CatBoostRegressor','CatBoostClassifier']: model=models[m][0](verbose=False)
112-
else: model=models[m][0]()
113-
modelScore[m]=cv_score(model,X,Y,k)
127+
modelScore[m]=eval_model(models,m,X,Y,k)
114128
prog.trials=prog.trials-1
115129
prog.update_progressbar(1)
116130
prog.update_progressbar(prog.trials)
@@ -133,10 +147,7 @@ def train_on_full_data(X,Y,models,best):
133147
modelScore={}
134148
prog.create_progressbar(len(best),"Deep Search (Stage 2 of 3) :")
135149
for m in best:
136-
if m in ['XGBClassifier','XGBRegressor']: model=models[m][0](verbosity=0)
137-
elif m in ['CatBoostRegressor','CatBoostClassifier']: model=models[m][0](verbose=False)
138-
else: model=models[m][0]()
139-
modelScore[m]=cv_score(model,X,Y,k)
150+
modelScore[m]=eval_model(models,m,X,Y,k)
140151
prog.trials=prog.trials-1
141152
prog.update_progressbar(1)
142153
prog.update_progressbar(prog.trials)

0 commit comments

Comments
 (0)