Skip to content

Commit a1a1a7c

Browse files
authored
[AUTOTVM][FIX] Typo fixes and add a warning in the Droplet Search (apache#16289)
Typo fixes and add warning
1 parent 9caa179 commit a1a1a7c

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

python/tvm/autotvm/tuner/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@
2828
from .index_based_tuner import GridSearchTuner, RandomTuner
2929
from .ga_tuner import GATuner
3030
from .xgboost_tuner import XGBTuner
31-
from .droplet_turner import DropletTuner
31+
from .droplet_tuner import DropletTuner

python/tvm/autotvm/tuner/droplet_turner.py renamed to python/tvm/autotvm/tuner/droplet_tuner.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Tuner with droplet algorithm"""
1818

1919
import logging
20+
import os
2021
import numpy as np
2122
from scipy import stats
2223
from .tuner import Tuner
@@ -44,13 +45,15 @@ def __init__(self, task, start_position=None, pvalue=0.05):
4445

4546
for _, v in self.space.space_map.items():
4647
self.dims.append(len(v))
48+
if len(self.dims) == 0:
49+
self.dims.append(1)
4750

4851
# start position
4952
start_position = [0] * len(self.dims) if start_position is None else start_position
5053
self.best_choice = (-1, [0] * len(self.dims), [99999])
5154
self.visited = set([self.space.knob2point(start_position)])
52-
self.execution, self.total_execution, self.batch = 1, max(self.dims), 16
53-
self.pvalue, self.step = pvalue, 1
55+
self.execution, self.total_execution, self.pvalue = 1, max(self.dims), pvalue
56+
self.step, self.iter, self.batch = 1, 0, max(16, os.cpu_count())
5457
self.next = [(self.space.knob2point(start_position), start_position)]
5558

5659
def num_to_bin(self, value, factor=1):
@@ -100,14 +103,15 @@ def speculation(self):
100103
self.next += self.next_pos(self.search_space(self.execution))
101104

102105
def update(self, inputs, results):
103-
found_best_pos = False
106+
found_best_pos, count_valids = False, 0
104107
for i, (_, res) in enumerate(zip(inputs, results)):
105108
try:
106109
if np.mean(self.best_choice[2]) > np.mean(res.costs) and self.p_value(
107110
self.best_choice[2], res.costs
108111
):
109112
self.best_choice = (self.next[i][0], self.next[i][1], res.costs)
110113
found_best_pos = True
114+
count_valids += 1
111115
except TypeError:
112116
LOGGER.debug("Solution is not valid")
113117
continue
@@ -119,6 +123,13 @@ def update(self, inputs, results):
119123
self.next += self.next_pos(self.search_space())
120124
self.execution = 1
121125
self.speculation()
126+
# stop, because all neighborhoods are invalid.
127+
if count_valids == 0 and self.iter > 3:
128+
self.next = []
129+
LOGGER.warning(
130+
f"Warning: early termination due to an all-invalid neighborhood \
131+
after {self.iter} iterations"
132+
)
122133

123134
def has_next(self):
124135
return len(self.next) > 0

0 commit comments

Comments
 (0)