Commit 379c007

Added InnerPredictor class
1 parent 1a6393a · commit 379c007

3 files changed: +129, -94 lines

lib/lightgbm.rb

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 require_relative "lightgbm/utils"
 require_relative "lightgbm/booster"
 require_relative "lightgbm/dataset"
+require_relative "lightgbm/inner_predictor"
 require_relative "lightgbm/version"

 # scikit-learn API

lib/lightgbm/booster.rb

Lines changed: 10 additions & 94 deletions
@@ -141,52 +141,23 @@ def num_trees
       out.read_int
     end

-    def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_leaf: false, pred_contrib: false, **kwargs)
+    def predict(data, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **kwargs)
+      predictor = InnerPredictor.from_booster(self, kwargs.transform_values(&:dup))
       if num_iteration.nil?
         if start_iteration <= 0
           num_iteration = best_iteration
         else
           num_iteration = -1
         end
       end
-
-      if data.is_a?(Dataset)
-        raise TypeError, "Cannot use Dataset instance for prediction, please use raw data instead"
-      end
-
-      predict_type = FFI::C_API_PREDICT_NORMAL
-      if raw_score
-        predict_type = FFI::C_API_PREDICT_RAW_SCORE
-      end
-      if pred_leaf
-        predict_type = FFI::C_API_PREDICT_LEAF_INDEX
-      end
-      if pred_contrib
-        predict_type = FFI::C_API_PREDICT_CONTRIB
-      end
-
-      preds, nrow, singular =
-        preds_for_data(
-          data,
-          start_iteration,
-          num_iteration,
-          predict_type,
-          **kwargs
-        )
-
-      if pred_leaf
-        preds = preds.map(&:to_i)
-      end
-
-      if preds.size != nrow
-        if preds.size % nrow == 0
-          preds = preds.each_slice(preds.size / nrow).to_a
-        else
-          raise Error, "Length of predict result (#{preds.size}) cannot be divide nrow (#{nrow})"
-        end
-      end
-
-      singular ? preds.first : preds
+      predictor.predict(
+        data,
+        start_iteration: start_iteration,
+        num_iteration: num_iteration,
+        raw_score: raw_score,
+        pred_leaf: pred_leaf,
+        pred_contrib: pred_contrib
+      )
     end

     def save_model(filename, num_iteration: nil, start_iteration: 0)
@@ -261,61 +232,6 @@ def num_class
       out.read_int
     end

-    def preds_for_data(input, start_iteration, num_iteration, predict_type, **params)
-      input =
-        if daru?(input)
-          input[*cached_feature_name].map_rows(&:to_a)
-        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
-          sorted_feature_values(input)
-        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
-          input.map(&method(:sorted_feature_values))
-        elsif rover?(input)
-          # TODO improve performance
-          input[cached_feature_name].to_numo.to_a
-        else
-          input.to_a
-        end
-
-      singular = !input.first.is_a?(Array)
-      input = [input] if singular
-
-      nrow = input.count
-      n_preds =
-        num_preds(
-          start_iteration,
-          num_iteration,
-          nrow,
-          predict_type
-        )
-
-      flat_input = input.flatten
-      handle_missing(flat_input)
-      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
-      data.write_array_of_double(flat_input)
-
-      out_len = ::FFI::MemoryPointer.new(:int64)
-      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
-      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, params_str(params), out_len, out_result)
-
-      if n_preds != out_len.read_int64
-        raise Error, "Wrong length for predict results"
-      end
-
-      preds = out_result.read_array_of_double(out_len.read_int64)
-
-      [preds, nrow, singular]
-    end
-
-    def num_preds(start_iteration, num_iteration, nrow, predict_type)
-      out = ::FFI::MemoryPointer.new(:int64)
-      check_result FFI.LGBM_BoosterCalcNumPredict(handle_pointer, nrow, predict_type, start_iteration, num_iteration, out)
-      out.read_int64
-    end
-
-    def sorted_feature_values(input_hash)
-      input_hash.transform_keys(&:to_s).fetch_values(*cached_feature_name)
-    end
-
     def cached_feature_name
       @cached_feature_name ||= feature_name
     end
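
Note: the public Booster#predict keeps its keyword arguments and now delegates to InnerPredictor internally, so existing callers are unaffected. A minimal usage sketch (the model file name and feature values below are made up for illustration):

    # load a trained model and predict with the same options as before
    booster = LightGBM::Booster.new(model_file: "model.txt")
    booster.predict([[3.7, 1.2, 7.2], [7.5, 0.5, 7.9]])     # regular predictions
    booster.predict([[3.7, 1.2, 7.2]], pred_leaf: true)     # leaf indices per tree
    booster.predict([[3.7, 1.2, 7.2]], num_iteration: 50)   # use only the first 50 iterations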

lib/lightgbm/inner_predictor.rb

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+module LightGBM
+  class InnerPredictor
+    def initialize(booster, pred_parameter)
+      @booster = booster
+      @pred_parameter = params_str(pred_parameter)
+    end
+
+    def self.from_booster(booster, pred_parameter)
+      new(booster, pred_parameter)
+    end
+
+    def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_leaf: false, pred_contrib: false)
+      if data.is_a?(Dataset)
+        raise TypeError, "Cannot use Dataset instance for prediction, please use raw data instead"
+      end
+
+      predict_type = FFI::C_API_PREDICT_NORMAL
+      if raw_score
+        predict_type = FFI::C_API_PREDICT_RAW_SCORE
+      end
+      if pred_leaf
+        predict_type = FFI::C_API_PREDICT_LEAF_INDEX
+      end
+      if pred_contrib
+        predict_type = FFI::C_API_PREDICT_CONTRIB
+      end
+
+      preds, nrow, singular =
+        preds_for_data(
+          data,
+          start_iteration,
+          num_iteration,
+          predict_type
+        )
+
+      if pred_leaf
+        preds = preds.map(&:to_i)
+      end
+
+      if preds.size != nrow
+        if preds.size % nrow == 0
+          preds = preds.each_slice(preds.size / nrow).to_a
+        else
+          raise Error, "Length of predict result (#{preds.size}) cannot be divide nrow (#{nrow})"
+        end
+      end
+
+      singular ? preds.first : preds
+    end
+
+    private
+
+    def handle_pointer
+      @booster.send(:handle_pointer)
+    end
+
+    def preds_for_data(input, start_iteration, num_iteration, predict_type)
+      input =
+        if daru?(input)
+          input[*cached_feature_name].map_rows(&:to_a)
+        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
+          sorted_feature_values(input)
+        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
+          input.map(&method(:sorted_feature_values))
+        elsif rover?(input)
+          # TODO improve performance
+          input[cached_feature_name].to_numo.to_a
+        else
+          input.to_a
+        end
+
+      singular = !input.first.is_a?(Array)
+      input = [input] if singular
+
+      nrow = input.count
+      n_preds =
+        num_preds(
+          start_iteration,
+          num_iteration,
+          nrow,
+          predict_type
+        )
+
+      flat_input = input.flatten
+      handle_missing(flat_input)
+      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
+      data.write_array_of_double(flat_input)
+
+      out_len = ::FFI::MemoryPointer.new(:int64)
+      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
+      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, @pred_parameter, out_len, out_result)
+
+      if n_preds != out_len.read_int64
+        raise Error, "Wrong length for predict results"
+      end
+
+      preds = out_result.read_array_of_double(out_len.read_int64)
+
+      [preds, nrow, singular]
+    end
+
+    def num_preds(start_iteration, num_iteration, nrow, predict_type)
+      out = ::FFI::MemoryPointer.new(:int64)
+      check_result FFI.LGBM_BoosterCalcNumPredict(handle_pointer, nrow, predict_type, start_iteration, num_iteration, out)
+      out.read_int64
+    end
+
+    def sorted_feature_values(input_hash)
+      input_hash.transform_keys(&:to_s).fetch_values(*cached_feature_name)
+    end
+
+    def cached_feature_name
+      @booster.send(:cached_feature_name)
+    end
+
+    include Utils
+  end
+end
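
Note: InnerPredictor is built via from_booster, which takes the booster plus a hash of extra prediction parameters; the hash is serialized once with params_str in initialize and reused on every FFI call. A sketch of the wiring (model file and parameter are illustrative; Booster#predict remains the intended entry point, the class is called directly here only to show the construction):

    booster = LightGBM::Booster.new(model_file: "model.txt")
    predictor = LightGBM::InnerPredictor.from_booster(booster, {"predict_disable_shape_check" => true})
    predictor.predict([[3.7, 1.2, 7.2]], raw_score: true)   # raw scores via the new class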
