|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | +from BiFuncLib.simulation_data import lbm_sim_data |
| 4 | +from BiFuncLib.lbm_bifunc import lbm_bifunc |
| 5 | +from BiFuncLib.lbm_main_func import ari |
| 6 | + |
| 7 | +def _check_lbm_result(res: dict): |
| 8 | + assert isinstance(res, dict) |
| 9 | + assert 'row_clust' in res |
| 10 | + assert 'col_clust' in res |
| 11 | + assert len(res['row_clust']) > 0 |
| 12 | + assert len(res['col_clust']) > 0 |
| 13 | + |
| 14 | +def test_lbm_sim_data(): |
| 15 | + lbm_sim = lbm_sim_data(n=50, p=50, t=15, seed=42) |
| 16 | + assert isinstance(lbm_sim, dict) |
| 17 | + assert 'data' in lbm_sim |
| 18 | + assert lbm_sim['data'].shape == (50, 15) |
| 19 | + |
| 20 | +def test_lbm_bifunc_basic(): |
| 21 | + """最基本的 LBM 调用""" |
| 22 | + lbm_sim = lbm_sim_data(n=60, p=60, t=20, seed=123) |
| 23 | + data = lbm_sim['data'] |
| 24 | + res = lbm_bifunc(data, K=4, L=3, display=False, basis_name = 'spline', init='funFEM') |
| 25 | + _check_lbm_result(res) |
| 26 | + row_ari = ari(res['row_clust'], lbm_sim['row_clust']) |
| 27 | + col_ari = ari(res['col_clust'], lbm_sim['col_clust']) |
| 28 | + assert 0 <= row_ari <= 1 |
| 29 | + assert 0 <= col_ari <= 1 |
| 30 | + |
| 31 | +def test_lbm_bifunc_grid(): |
| 32 | + lbm_sim = lbm_sim_data(n=40, p=40, t=10, bivariate=True, seed=456) |
| 33 | + data = [lbm_sim['data1'], lbm_sim['data2']] |
| 34 | + res = lbm_bifunc(data, K=[2, 3], L=[2, 3], display=False) |
| 35 | + _check_lbm_result(res) |
| 36 | + |
| 37 | +def test_lbm_bifunc_user_init(): |
| 38 | + lbm_sim = lbm_sim_data(n=30, p=30, t=10, seed=789) |
| 39 | + data = lbm_sim['data'] |
| 40 | + res0 = lbm_bifunc(data, K=3, L=2, display=True, init='kmeans') |
| 41 | + res1 = lbm_bifunc(data, K=[res0['K']], L=[res0['L']], |
| 42 | + init='user', |
| 43 | + row_init=res0['row_clust'], |
| 44 | + col_init=res0['col_clust'], |
| 45 | + display=False) |
| 46 | + _check_lbm_result(res1) |
0 commit comments