Skip to content

Commit 9d350ba

Browse files
committed
Added support for hashes to Dataset
1 parent 0891638 commit 9d350ba

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.4.0 (unreleased)
22

33
- Added support for hashes and Rover data frames to `predict` method
4+
- Added support for hashes to `Dataset`
45
- Changed `Dataset` to use column names for feature names with Rover and Daru
56
- Changed `predict` method to match feature names with Daru
67
- Dropped support for Ruby < 3.1

lib/lightgbm/dataset.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def group=(group)
6666
end
6767

6868
def feature_name=(feature_names)
69+
feature_names = feature_names.map(&:to_s)
6970
@feature_names = feature_names
7071
c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size)
7172
# keep reference to string pointers
@@ -154,6 +155,14 @@ def construct
154155
end
155156
data = data.to_numo
156157
nrow, ncol = data.shape
158+
elsif data.is_a?(Array) && data.first.is_a?(Hash)
159+
keys = data.first.keys
160+
if @feature_name == "auto"
161+
@feature_name = keys
162+
end
163+
nrow = data.count
164+
ncol = data.first.count
165+
flat_data = data.flat_map { |v| v.fetch_values(*keys) }
157166
else
158167
nrow = data.count
159168
ncol = data.first.count

test/dataset_test.rb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ def test_dump_text
5050
assert File.exist?(tempfile)
5151
end
5252

53+
def test_hashes_string_keys
54+
data = [{"x0" => 1, "x1" => 2}, {"x0" => 3, "x1" => 4}, {"x0" => 5, "x1" => 6}]
55+
dataset = LightGBM::Dataset.new(data)
56+
assert_equal 3, dataset.num_data
57+
assert_equal 2, dataset.num_feature
58+
assert_equal ["x0", "x1"], dataset.feature_name
59+
end
60+
61+
def test_hashes_symbol_keys
62+
data = [{x0: 1, x1: 2}, {x0: 3, x1: 4}, {x0: 5, x1: 6}]
63+
dataset = LightGBM::Dataset.new(data)
64+
assert_equal 3, dataset.num_data
65+
assert_equal 2, dataset.num_feature
66+
assert_equal ["x0", "x1"], dataset.feature_name
67+
end
68+
5369
def test_matrix
5470
data = Matrix.build(3, 3) { |row, col| row + col }
5571
label = Vector.elements([4, 5, 6])

0 commit comments

Comments
 (0)