Skip to content

Commit 3c48eac

Browse files
committed
updated torch requirement
1 parent c07aa31 commit 3c48eac

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ dependencies:
88
- matplotlib==3.2.2
99
- numpy==1.18.5
1010
- scipy==1.4.1
11-
- torch==1.2.0
11+
- torch==1.6.0

model_arch.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,14 @@ def _forward_oneQ_batch(self, samp_batch):
146146
147147
return output_vec, where 1st col is predicted time
148148
'''
149-
#print(samp_batch)
149+
# print(samp_batch)
150150
feat_vec = samp_batch['feat_vec']
151151
# print(samp_batch['real_node_type'])
152152

153153
# print(samp_batch['node_type'])
154154
# print(feat_vec.shape, print(samp_batch['children_plan']))
155155
input_vec = torch.from_numpy(feat_vec).to(self.device)
156-
#print(samp_batch['node_type'], input_vec)
156+
# print(samp_batch['node_type'], input_vec)
157157
subplans_time = []
158158
for child_plan_dict in samp_batch['children_plan']:
159159
child_output_vec, _ = self._forward_oneQ_batch(child_plan_dict)
@@ -176,15 +176,16 @@ def _forward_oneQ_batch(self, samp_batch):
176176
# pred_time assumed to be the first col
177177

178178
cat_res = torch.cat([pred_time] + subplans_time, axis=1)
179-
#print("cat_res.shape", cat_res.shape)
179+
# print("cat_res.shape", cat_res.shape)
180180
pred_time = torch.sum(cat_res, 1)
181-
#print("pred_time.shape", pred_time.shape)
182-
if self.test_time:
183-
print(samp_batch['node_type'], pred_time, samp_batch['total_time'])
181+
# print("pred_time.shape", pred_time.shape)
182+
183+
# if self.test_time:
184+
# print(samp_batch['node_type'], pred_time, samp_batch['total_time'])
184185

185186
loss = (pred_time -
186187
torch.from_numpy(samp_batch['total_time']).to(self.device)) ** 2
187-
#print("loss.shape", loss.shape)
188+
# print("loss.shape", loss.shape)
188189
self.acc_loss[samp_batch['node_type']].append(loss)
189190

190191
# added to deal with NaN

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
matplotlib==3.2.2
22
numpy==1.18.5
33
scipy==1.4.1
4-
torch==1.2.0
4+
torch==1.6.0

0 commit comments

Comments
 (0)