@@ -141,52 +141,23 @@ def num_trees
141
141
out . read_int
142
142
end
143
143
144
- def predict ( data , start_iteration : 0 , num_iteration : -1 , raw_score : false , pred_leaf : false , pred_contrib : false , **kwargs )
144
+ def predict ( data , start_iteration : 0 , num_iteration : nil , raw_score : false , pred_leaf : false , pred_contrib : false , **kwargs )
145
+ predictor = InnerPredictor . from_booster ( self , kwargs . transform_values ( &:dup ) )
145
146
if num_iteration . nil?
146
147
if start_iteration <= 0
147
148
num_iteration = best_iteration
148
149
else
149
150
num_iteration = -1
150
151
end
151
152
end
152
-
153
- if data . is_a? ( Dataset )
154
- raise TypeError , "Cannot use Dataset instance for prediction, please use raw data instead"
155
- end
156
-
157
- predict_type = FFI ::C_API_PREDICT_NORMAL
158
- if raw_score
159
- predict_type = FFI ::C_API_PREDICT_RAW_SCORE
160
- end
161
- if pred_leaf
162
- predict_type = FFI ::C_API_PREDICT_LEAF_INDEX
163
- end
164
- if pred_contrib
165
- predict_type = FFI ::C_API_PREDICT_CONTRIB
166
- end
167
-
168
- preds , nrow , singular =
169
- preds_for_data (
170
- data ,
171
- start_iteration ,
172
- num_iteration ,
173
- predict_type ,
174
- **kwargs
175
- )
176
-
177
- if pred_leaf
178
- preds = preds . map ( &:to_i )
179
- end
180
-
181
- if preds . size != nrow
182
- if preds . size % nrow == 0
183
- preds = preds . each_slice ( preds . size / nrow ) . to_a
184
- else
185
- raise Error , "Length of predict result (#{ preds . size } ) cannot be divide nrow (#{ nrow } )"
186
- end
187
- end
188
-
189
- singular ? preds . first : preds
153
+ predictor . predict (
154
+ data ,
155
+ start_iteration : start_iteration ,
156
+ num_iteration : num_iteration ,
157
+ raw_score : raw_score ,
158
+ pred_leaf : pred_leaf ,
159
+ pred_contrib : pred_contrib
160
+ )
190
161
end
191
162
192
163
def save_model ( filename , num_iteration : nil , start_iteration : 0 )
@@ -261,61 +232,6 @@ def num_class
261
232
out . read_int
262
233
end
263
234
264
- def preds_for_data ( input , start_iteration , num_iteration , predict_type , **params )
265
- input =
266
- if daru? ( input )
267
- input [ *cached_feature_name ] . map_rows ( &:to_a )
268
- elsif input . is_a? ( Hash ) # sort feature.values to match the order of model.feature_name
269
- sorted_feature_values ( input )
270
- elsif input . is_a? ( Array ) && input . first . is_a? ( Hash ) # on multiple elems, if 1st is hash, assume they all are
271
- input . map ( &method ( :sorted_feature_values ) )
272
- elsif rover? ( input )
273
- # TODO improve performance
274
- input [ cached_feature_name ] . to_numo . to_a
275
- else
276
- input . to_a
277
- end
278
-
279
- singular = !input . first . is_a? ( Array )
280
- input = [ input ] if singular
281
-
282
- nrow = input . count
283
- n_preds =
284
- num_preds (
285
- start_iteration ,
286
- num_iteration ,
287
- nrow ,
288
- predict_type
289
- )
290
-
291
- flat_input = input . flatten
292
- handle_missing ( flat_input )
293
- data = ::FFI ::MemoryPointer . new ( :double , input . count * input . first . count )
294
- data . write_array_of_double ( flat_input )
295
-
296
- out_len = ::FFI ::MemoryPointer . new ( :int64 )
297
- out_result = ::FFI ::MemoryPointer . new ( :double , n_preds )
298
- check_result FFI . LGBM_BoosterPredictForMat ( handle_pointer , data , 1 , input . count , input . first . count , 1 , predict_type , start_iteration , num_iteration , params_str ( params ) , out_len , out_result )
299
-
300
- if n_preds != out_len . read_int64
301
- raise Error , "Wrong length for predict results"
302
- end
303
-
304
- preds = out_result . read_array_of_double ( out_len . read_int64 )
305
-
306
- [ preds , nrow , singular ]
307
- end
308
-
309
- def num_preds ( start_iteration , num_iteration , nrow , predict_type )
310
- out = ::FFI ::MemoryPointer . new ( :int64 )
311
- check_result FFI . LGBM_BoosterCalcNumPredict ( handle_pointer , nrow , predict_type , start_iteration , num_iteration , out )
312
- out . read_int64
313
- end
314
-
315
- def sorted_feature_values ( input_hash )
316
- input_hash . transform_keys ( &:to_s ) . fetch_values ( *cached_feature_name )
317
- end
318
-
319
235
def cached_feature_name
320
236
@cached_feature_name ||= feature_name
321
237
end
0 commit comments