Skip to content

Commit 73ff8cf

Browse files
authored
Create funlbm_test.py
1 parent ff6a00b commit 73ff8cf

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

pytest/funlbm_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import pytest
3+
from BiFuncLib.simulation_data import lbm_sim_data
4+
from BiFuncLib.lbm_bifunc import lbm_bifunc
5+
from BiFuncLib.lbm_main_func import ari
6+
7+
def _check_lbm_result(res: dict):
8+
assert isinstance(res, dict)
9+
assert 'row_clust' in res
10+
assert 'col_clust' in res
11+
assert len(res['row_clust']) > 0
12+
assert len(res['col_clust']) > 0
13+
14+
def test_lbm_sim_data():
15+
lbm_sim = lbm_sim_data(n=50, p=50, t=15, seed=42)
16+
assert isinstance(lbm_sim, dict)
17+
assert 'data' in lbm_sim
18+
assert lbm_sim['data'].shape == (50, 15)
19+
20+
def test_lbm_bifunc_basic():
21+
"""最基本的 LBM 调用"""
22+
lbm_sim = lbm_sim_data(n=60, p=60, t=20, seed=123)
23+
data = lbm_sim['data']
24+
res = lbm_bifunc(data, K=4, L=3, display=False, basis_name = 'spline', init='funFEM')
25+
_check_lbm_result(res)
26+
row_ari = ari(res['row_clust'], lbm_sim['row_clust'])
27+
col_ari = ari(res['col_clust'], lbm_sim['col_clust'])
28+
assert 0 <= row_ari <= 1
29+
assert 0 <= col_ari <= 1
30+
31+
def test_lbm_bifunc_grid():
32+
lbm_sim = lbm_sim_data(n=40, p=40, t=10, bivariate=True, seed=456)
33+
data = [lbm_sim['data1'], lbm_sim['data2']]
34+
res = lbm_bifunc(data, K=[2, 3], L=[2, 3], display=False)
35+
_check_lbm_result(res)
36+
37+
def test_lbm_bifunc_user_init():
38+
lbm_sim = lbm_sim_data(n=30, p=30, t=10, seed=789)
39+
data = lbm_sim['data']
40+
res0 = lbm_bifunc(data, K=3, L=2, display=True, init='kmeans')
41+
res1 = lbm_bifunc(data, K=[res0['K']], L=[res0['L']],
42+
init='user',
43+
row_init=res0['row_clust'],
44+
col_init=res0['col_clust'],
45+
display=False)
46+
_check_lbm_result(res1)

0 commit comments

Comments
 (0)