Skip to content

Commit 19b1438

Browse files
committed
Added migrations for standard functions of groupBy, with tests
1 parent 40a8b5c commit 19b1438

File tree

5 files changed

+104
-13
lines changed

5 files changed

+104
-13
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/forEach.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ public inline fun <T> DataFrame<T>.forEach(action: RowExpression<T, Unit>): Unit
2020

2121
// region GroupBy
2222

23-
@Deprecated("Replaced with forEachEntry")
23+
@Deprecated(
24+
"Replaced with forEachEntry",
25+
ReplaceWith("forEachEntry { val key = it\nval group = it.group()\nbody(key, group) }"),
26+
)
2427
public inline fun <T, G> GroupBy<T, G>.forEach(body: (GroupBy.Entry<T, G>) -> Unit): Unit =
2528
keys.forEach { key ->
2629
val group = groups[key.index()]

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public interface GroupBy<out T, out G> : Grouped<G> {
122122

123123
public fun <R> updateGroups(transform: Selector<DataFrame<G>, DataFrame<R>>): GroupBy<T, R>
124124

125-
@Deprecated("Replaced by filterEntries")
125+
@Deprecated("Replaced by filterEntries", ReplaceWith("filterEntries(predicate)"))
126126
public fun filter(predicate: GroupedRowFilter<T, G>): GroupBy<T, G>
127127

128128
public fun filterEntries(predicate: GroupByEntryFilter<T, G>): GroupBy<T, G>

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,19 +142,28 @@ public inline fun <T> DataFrame<T>.mapToFrame(body: AddDsl<T>.() -> Unit): AnyFr
142142

143143
// region GroupBy
144144

145-
@Deprecated("Replaced by mapEntries")
145+
@Deprecated(
146+
"Replaced by mapEntries",
147+
ReplaceWith("mapEntries { val key = it\nval group = it.group()\nbody(key, group) }"),
148+
)
146149
public inline fun <T, G, R> GroupBy<T, G>.map(body: Selector<GroupWithKey<T, G>, R>): List<R> =
147150
keys.rows().mapIndexedNotNull { index, row ->
148151
val group = groups[index]
149152
val g = GroupWithKey(row, group)
150153
body(g, g)
151154
}
152155

153-
@Deprecated("Replaced by mapEntriesToRows")
156+
@Deprecated(
157+
"Replaced by mapEntriesToRows",
158+
ReplaceWith("mapEntriesToRows { val key = it\nval group = it.group()\nbody(key, group) }"),
159+
)
154160
public fun <T, G> GroupBy<T, G>.mapToRows(body: Selector<GroupWithKey<T, G>, DataRow<G>?>): DataFrame<G> =
155161
map(body).concat()
156162

157-
@Deprecated("Replaced by mapEntriesToFrames")
163+
@Deprecated(
164+
"Replaced by mapEntriesToFrames",
165+
ReplaceWith("mapEntriesToFrames { val key = it\nval group = it.group()\nbody(key, group) }"),
166+
)
158167
public fun <T, G> GroupBy<T, G>.mapToFrames(body: Selector<GroupWithKey<T, G>, DataFrame<G>>): FrameColumn<G> =
159168
DataColumn.createFrameColumn(groups.name, map(body))
160169

@@ -164,10 +173,11 @@ public inline fun <T, G, R> GroupBy<T, G>.mapEntries(body: GroupByEntrySelector<
164173
body(entry, entry)
165174
}
166175

167-
public fun <T, G> GroupBy<T, G>.mapEntriesToRows(body: GroupByEntrySelector<T, G, DataRow<G>?>): DataFrame<G> =
176+
public fun <T, G, R : Any> GroupBy<T, G>.mapEntriesToRows(body: GroupByEntrySelector<T, G, DataRow<R>?>): DataFrame<R> =
168177
mapEntries(body).concat()
169178

170-
public fun <T, G> GroupBy<T, G>.mapEntriesToFrames(body: GroupByEntrySelector<T, G, DataFrame<G>>): FrameColumn<G> =
171-
DataColumn.createFrameColumn(groups.name, mapEntries(body))
179+
public fun <T, G, R : Any> GroupBy<T, G>.mapEntriesToFrames(
180+
body: GroupByEntrySelector<T, G, DataFrame<R>>,
181+
): FrameColumn<R> = DataColumn.createFrameColumn(groups.name, mapEntries(body))
172182

173183
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ internal class GroupByImpl<T, G>(
4747
override val keys by lazy { df.remove { groups } }
4848

4949
@Suppress("UNCHECKED_CAST")
50-
override fun <R> updateGroups(transform: Selector<DataFrame<G>, DataFrame<R>>) =
50+
override fun <R> updateGroups(transform: Selector<DataFrame<G>, DataFrame<R>>): GroupBy<T, R> =
5151
df.convert { groups }.with { transform(it, it) }
52-
.asGroupBy(groups.name()) as GroupBy<T, R>
52+
.asGroupBy { frameCol<R>(groups.name()) }
5353

5454
override fun toString() = df.toString()
5555

@@ -62,15 +62,15 @@ internal class GroupByImpl<T, G>(
6262
val row = GroupedDataRowImpl(df[it], groups)
6363
predicate(row, row)
6464
}
65-
return df[indices].asGroupBy { groups }
65+
return df[indices].asGroupBy { frameCol<G>(groups.name()) }
6666
}
6767

6868
override fun filterEntries(predicate: GroupByEntryFilter<T, G>): GroupBy<T, G> {
6969
val indices = (0 until df.nrow).filter {
7070
val row = GroupByEntryImpl(df[it], groups)
7171
predicate(row, row)
7272
}
73-
return df[indices].asGroupBy { groups }
73+
return df[indices].asGroupBy { frameCol<G>(groups.name()) }
7474
}
7575

7676
override fun toDataFrame(groupedColumnName: String?): DataFrame<T> =

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package org.jetbrains.kotlinx.dataframe.api
22

33
import io.kotest.matchers.shouldBe
44
import org.jetbrains.kotlinx.dataframe.AnyFrame
5+
import org.jetbrains.kotlinx.dataframe.testSets.person.BaseTest
6+
import org.jetbrains.kotlinx.dataframe.testSets.person.age
57
import org.junit.Test
68
import kotlin.reflect.typeOf
79

810
@Suppress("ktlint:standard:argument-list-wrapping")
9-
class GroupByTests {
11+
class GroupByTests : BaseTest() {
1012

1113
@Test
1214
fun `groupBy values with nulls`() {
@@ -56,4 +58,80 @@ class GroupByTests {
5658
getFrameColumn("d") into "e"
5759
}["e"].type() shouldBe typeOf<List<AnyFrame>>()
5860
}
61+
62+
@Test
63+
fun `groupBy forEachEntry`() {
64+
val grouped = typed.groupBy { age }
65+
val entries1 = buildList {
66+
grouped.forEach { (key, group) ->
67+
add(key.toMap() to group)
68+
}
69+
}
70+
val entries2 = buildList {
71+
grouped.forEachEntry {
72+
add(it.toMap() to it.group())
73+
}
74+
}
75+
76+
entries1 shouldBe entries2
77+
}
78+
79+
@Test
80+
fun `groupBy mapEntries`() {
81+
// The old mapToRows and mapToFrames stick to the same return type, so let's make the types Any? for the test
82+
val grouped: GroupBy<Any?, Any?> = df.groupBy { age }
83+
val entries1 = grouped.map { (key, group) ->
84+
key.toMap() to group
85+
}
86+
val entries2 = grouped.mapEntries {
87+
it.toMap() to it.group()
88+
}
89+
entries1 shouldBe entries2
90+
91+
val entries3 = grouped.mapToRows { (key, group) ->
92+
listOf(key.toMap() to group).toDataFrame().single()
93+
}
94+
val entries4 = grouped.mapEntriesToRows {
95+
listOf(it.toMap() to it.group()).toDataFrame().single()
96+
}
97+
entries3 shouldBe entries4
98+
99+
val entries5 = grouped.mapToFrames { (key, group) ->
100+
listOf(key.toMap() to group).toDataFrame()
101+
}
102+
val entries6 = grouped.mapEntriesToFrames {
103+
listOf(it.toMap() to it.group()).toDataFrame()
104+
}
105+
entries5 shouldBe entries6
106+
107+
// let's test the -Entries variants with typed versions
108+
val grouped2 = typed.groupBy { age }
109+
110+
val entries7 = grouped2.mapEntries {
111+
it.toMap() to it.group()
112+
}
113+
val entries8 = grouped2.mapEntriesToRows {
114+
listOf(it.toMap() to it.group()).toDataFrame().single()
115+
}.toList()
116+
val entries9 = grouped2.mapEntriesToFrames {
117+
listOf(it.toMap() to it.group()).toDataFrame()
118+
}.map { it[0][0] to it[0][1] }.toList()
119+
entries7 shouldBe entries8
120+
entries8 shouldBe entries9
121+
}
122+
123+
@Test
124+
fun `groupBy filterEntries`() {
125+
val grouped = typed.groupBy { age }
126+
127+
val entries1 = grouped.filter { age == 20 }
128+
.mapEntries {
129+
it.toMap() to it.group()
130+
}
131+
val entries2 = grouped.filterEntries { age == 20 }
132+
.mapEntries {
133+
it.toMap() to it.group()
134+
}
135+
entries1 shouldBe entries2
136+
}
59137
}

0 commit comments

Comments
 (0)