Skip to content
This repository was archived by the owner on Jun 17, 2024. It is now read-only.

Commit 8b2bf94

Browse files
committed
add module to be exported
Signed-off-by: Kosaku Kimura <[email protected]>
1 parent f083668 commit 8b2bf94

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ build-backend = "poetry.core.masonry.api"
3232
[tool.poetry.plugins."code_block_generator"]
3333
loaddata = "sapientml_loaddata:LoadData"
3434

35+
[tool.poetry.plugins."export_module"]
36+
sample-dataset = "sapientml_loaddata.lib:sample_dataset"
37+
3538
[tool.pysen]
3639
version = "0.10"
3740

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from decimal import ROUND_HALF_UP, Decimal
2+
3+
import pandas as pd
4+
from sklearn.model_selection import train_test_split
5+
6+
7+
def _sampled_training(dev_training_dataset, train_size, stratify, task_type) -> pd.DataFrame:
    """Return a random subset of *dev_training_dataset* with ``train_size`` rows.

    Thin wrapper around :func:`sklearn.model_selection.train_test_split`
    that keeps only the "train" half of the split.  Stratification is
    honoured only for classification tasks; for any other ``task_type``
    the split is unstratified regardless of the ``stratify`` argument.
    """
    # train_test_split returns (train, rest); the remainder is discarded.
    strat = stratify if task_type == "classification" else None
    subset, _rest = train_test_split(
        dev_training_dataset,
        train_size=train_size,
        stratify=strat,
    )
    return subset  # type: ignore
14+
15+
16+
def sample_dataset(
    dataframe: pd.DataFrame,
    sample_size: int,
    target_columns: list[str],
    task_type: str,
) -> pd.DataFrame:
    """Downsample *dataframe* to roughly ``sample_size`` rows.

    Frames smaller than ``sample_size`` are returned unchanged.  For
    ``task_type == "classification"`` the sampling is stratified on
    ``target_columns``; rows whose combined target label occurs exactly
    once ("rare" labels) are split off and sampled separately, because a
    stratified split cannot handle singleton classes.

    Args:
        dataframe: Source data to sample from.
        sample_size: Target number of rows in the returned sample.
        target_columns: Target column name(s); used for stratification
            and rare-label detection (classification only).
        task_type: Task kind; stratification applies only when this
            equals ``"classification"``.

    Returns:
        A sampled ``pd.DataFrame`` of approximately ``sample_size`` rows,
        or ``dataframe`` itself when it already has fewer rows.
    """
    # Sample the training set if the dataset is big
    # FIXME
    sampled_training_dataset = None
    num_of_rows = len(dataframe.index)
    if num_of_rows >= sample_size:
        rare_labels = []
        dataframe_alltargets = None
        if task_type == "classification":
            # Collapse all target columns into a single string label per row
            # so multi-target problems stratify on the label combination.
            dataframe_alltargets = dataframe[target_columns].astype(str).apply("".join, axis=1)
            label_count = dataframe_alltargets.value_counts()
            # "Rare" labels are combinations that occur exactly once.
            rare_labels = label_count.loc[label_count == 1].index.tolist()

        if rare_labels and dataframe_alltargets is not None:
            # Split the frame into rows with singleton labels and the rest.
            dataframe_rare = dataframe[dataframe_alltargets.isin(rare_labels)]
            rare_index = dataframe_rare.index.values

            dataframe_wo_rare = dataframe.drop(rare_index)

            # Distinct label counts per target column of the non-rare part.
            num_of_labels = [len(dataframe_wo_rare[target].value_counts()) for target in target_columns]

            # Despite the "ratio" names, these are row COUNTS: each part's
            # proportional share of sample_size, rounded half-up.
            rare_to_all_ratio = int(
                Decimal(sample_size * len(dataframe_rare) / len(dataframe)).quantize(
                    Decimal("0"), rounding=ROUND_HALF_UP
                )
            )
            not_rare_to_all_ratio = int(
                Decimal(sample_size * len(dataframe_wo_rare) / len(dataframe)).quantize(
                    Decimal("0"), rounding=ROUND_HALF_UP
                )
            )

            stratify_wo_rare = None

            if len(dataframe_rare) == len(dataframe):
                # Every row is rare: no stratification is possible, sample plainly.
                sampled_training_dataset = _sampled_training(dataframe, sample_size, None, task_type)

            elif rare_to_all_ratio in [0, 1]:
                # The rare share rounds to <= 1 row: keep ALL rare rows and
                # draw the remainder from the non-rare part.
                sampled_training_dataset_rare = dataframe_rare

                if max(num_of_labels) >= sample_size:
                    # More distinct labels than rows to draw: stratification impossible.
                    stratify_wo_rare = None
                else:
                    stratify_wo_rare = dataframe_wo_rare[target_columns]
                sampled_training_dataset_wo_rare = _sampled_training(
                    dataframe_wo_rare,
                    sample_size - len(sampled_training_dataset_rare),
                    stratify_wo_rare,
                    task_type,
                )

                sampled_training_dataset = pd.concat(
                    [sampled_training_dataset_wo_rare, sampled_training_dataset_rare]  # type: ignore
                )

            elif not_rare_to_all_ratio in [0, 1]:
                # Mirror case: the non-rare share rounds to <= 1 row, so keep
                # ALL non-rare rows and draw the remainder from the rare part.
                sampled_training_dataset_wo_rare = dataframe_wo_rare
                sampled_training_dataset_rare = _sampled_training(
                    dataframe_rare,
                    sample_size - len(sampled_training_dataset_wo_rare),
                    None,
                    task_type,
                )

                sampled_training_dataset = pd.concat(
                    [sampled_training_dataset_wo_rare, sampled_training_dataset_rare]  # type: ignore
                )

            else:
                # General case: sample each part in proportion to its size.
                if max(num_of_labels) >= sample_size:
                    stratify_wo_rare = None
                else:
                    stratify_wo_rare = dataframe_wo_rare[target_columns]

                sampled_training_dataset_wo_rare = _sampled_training(
                    dataframe_wo_rare, not_rare_to_all_ratio, stratify_wo_rare, task_type
                )
                # Rare rows are singletons, so they can never be stratified.
                sampled_training_dataset_rare = _sampled_training(dataframe_rare, rare_to_all_ratio, None, task_type)

                sampled_training_dataset = pd.concat(
                    [sampled_training_dataset_wo_rare, sampled_training_dataset_rare]  # type: ignore
                )

        else:
            # No rare labels (or a regression task): one straight sample of
            # the whole frame, stratified when label cardinality allows it.
            num_of_labels = [len(dataframe[target].value_counts()) for target in target_columns]
            if max(num_of_labels) >= sample_size:
                stratify_wo_rare = None
            else:
                stratify_wo_rare = dataframe[target_columns]

            sampled_training_dataset = _sampled_training(dataframe, sample_size, stratify_wo_rare, task_type)
        return sampled_training_dataset
    else:
        # Already small enough: return the frame untouched.
        return dataframe

0 commit comments

Comments
 (0)