17
17
"""Tuner with droplet algorithm"""
18
18
19
19
import logging
20
+ import os
20
21
import numpy as np
21
22
from scipy import stats
22
23
from .tuner import Tuner
@@ -44,13 +45,15 @@ def __init__(self, task, start_position=None, pvalue=0.05):
44
45
45
46
for _ , v in self .space .space_map .items ():
46
47
self .dims .append (len (v ))
48
+ if len (self .dims ) == 0 :
49
+ self .dims .append (1 )
47
50
48
51
# start position
49
52
start_position = [0 ] * len (self .dims ) if start_position is None else start_position
50
53
self .best_choice = (- 1 , [0 ] * len (self .dims ), [99999 ])
51
54
self .visited = set ([self .space .knob2point (start_position )])
52
- self .execution , self .total_execution , self .batch = 1 , max (self .dims ), 16
53
- self .pvalue , self .step = pvalue , 1
55
+ self .execution , self .total_execution , self .pvalue = 1 , max (self .dims ), pvalue
56
+ self .step , self .iter , self . batch = 1 , 0 , max ( 16 , os . cpu_count ())
54
57
self .next = [(self .space .knob2point (start_position ), start_position )]
55
58
56
59
def num_to_bin (self , value , factor = 1 ):
@@ -100,14 +103,15 @@ def speculation(self):
100
103
self .next += self .next_pos (self .search_space (self .execution ))
101
104
102
105
def update (self , inputs , results ):
103
- found_best_pos = False
106
+ found_best_pos , count_valids = False , 0
104
107
for i , (_ , res ) in enumerate (zip (inputs , results )):
105
108
try :
106
109
if np .mean (self .best_choice [2 ]) > np .mean (res .costs ) and self .p_value (
107
110
self .best_choice [2 ], res .costs
108
111
):
109
112
self .best_choice = (self .next [i ][0 ], self .next [i ][1 ], res .costs )
110
113
found_best_pos = True
114
+ count_valids += 1
111
115
except TypeError :
112
116
LOGGER .debug ("Solution is not valid" )
113
117
continue
@@ -119,6 +123,13 @@ def update(self, inputs, results):
119
123
self .next += self .next_pos (self .search_space ())
120
124
self .execution = 1
121
125
self .speculation ()
126
+ # stop, because all neighborhoods are invalid.
127
+ if count_valids == 0 and self .iter > 3 :
128
+ self .next = []
129
+ LOGGER .warning (
130
+ f"Warning: early termination due to an all-invalid neighborhood \
131
+ after { self .iter } iterations"
132
+ )
122
133
123
134
def has_next (self ):
124
135
return len (self .next ) > 0
0 commit comments