Skip to content

Commit 8aadbaf

Browse files
committed
Added support for Rover data frames to predict method
1 parent 3792f19 commit 8aadbaf

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.4.0 (unreleased)
22

3-
- Added support for hashes to `predict` method
3+
- Added support for hashes and Rover data frames to `predict` method
44
- Changed `Dataset` to use column names for feature names with Rover and Daru
55
- Dropped support for Ruby < 3.1
66

lib/lightgbm/booster.rb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
146146
sorted_feature_values(input)
147147
elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
148148
input.map(&method(:sorted_feature_values))
149+
elsif rover?(input)
150+
# TODO improve performance
151+
input[feature_name()].to_numo.to_a
149152
else
150153
input.to_a
151154
end

test/booster_test.rb

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ def test_predict_hash
5050
end
5151
end
5252

53+
def test_predict_rover
54+
require "rover"
55+
x_test =
56+
Rover::DataFrame.new([
57+
{"x3" => 9.0, "x2" => 7.2, "x1" => 1.2, "x0" => 3.7},
58+
{"x3" => 0.0, "x2" => 7.9, "x1" => 0.5, "x0" => 7.5},
59+
])
60+
pred = booster.predict(x_test)
61+
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], pred.first(2)
62+
end
63+
5364
def test_model_to_string
5465
assert booster.model_to_string
5566
end

0 commit comments

Comments
 (0)