Skip to content

Commit 60665c0

Browse files
authored
Merge pull request #256 from cnellington/dagma_loss_device_fix
Multi-device DAGMA Training and No-val-set Easy training
2 parents 4527298 + 80cf374 commit 60665c0

File tree

4 files changed

+31
-33
lines changed

4 files changed

+31
-33
lines changed

contextualized/dags/losses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def dag_loss_dagma_indiv(w, s=1):
7-
M = s * torch.eye(w.shape[-1]) - w * w
7+
M = s * torch.eye(w.shape[-1]).to(w.device) - w * w
88
return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]
99

1010

@@ -18,7 +18,7 @@ def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
1818

1919
def dag_loss_poly_indiv(w):
2020
d = w.shape[-1]
21-
return torch.trace((torch.eye(d) + (1 / d) * torch.matmul(w, w))^d) - d
21+
return torch.trace((torch.eye(d).to(w.device) + (1 / d) * torch.matmul(w, w)) ** d) - d
2222

2323

2424
def dag_loss_poly(W, **kwargs):

contextualized/dags/tests.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytorch_lightning.callbacks import LearningRateFinder
1010

1111

12-
from contextualized.dags.lightning_modules import NOTMAD
12+
from contextualized.dags.lightning_modules import NOTMAD, DEFAULT_SS_PARAMS, DEFAULT_ARCH_PARAMS
1313
from contextualized.dags import graph_utils
1414
from contextualized.dags.trainers import GraphTrainer
1515
from contextualized.dags.losses import mse_loss as mse
@@ -37,26 +37,21 @@ def _train(self, model_args, n_epochs):
3737
model = NOTMAD(
3838
self.C.shape[-1],
3939
self.X.shape[-1],
40-
archetype_params={
40+
archetype_loss_params={
4141
"l1": 0.0,
4242
"dag": model_args.get(
4343
"dag",
44-
{
45-
"loss_type": "NOTEARS",
46-
"params": {
47-
"alpha": 1e-1,
48-
"rho": 1e-2,
49-
"h_old": 0.0,
50-
"tol": 0.25,
51-
"use_dynamic_alpha_rho": True,
52-
},
53-
},
44+
DEFAULT_ARCH_PARAMS["dag"],
5445
),
5546
"init_mat": INIT_MAT,
5647
"num_factors": model_args.get("num_factors", 0),
5748
"factor_mat_l1": 0.0,
5849
"num_archetypes": model_args.get("num_archetypes", k),
5950
},
51+
sample_specific_loss_params= {
52+
"l1": 0.0,
53+
"dag": DEFAULT_SS_PARAMS["dag"],
54+
}
6055
)
6156
dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
6257
trainer = GraphTrainer(
@@ -181,26 +176,21 @@ def _train(self, model_args, n_epochs):
181176
model = NOTMAD(
182177
self.C.shape[-1],
183178
self.X.shape[-1],
184-
archetype_params={
179+
archetype_loss_params={
185180
"l1": 0.0,
186181
"dag": model_args.get(
187182
"dag",
188-
{
189-
"loss_type": "NOTEARS",
190-
"params": {
191-
"alpha": 1e-1,
192-
"rho": 1e-2,
193-
"h_old": 0.0,
194-
"tol": 0.25,
195-
"use_dynamic_alpha_rho": True,
196-
},
197-
},
183+
DEFAULT_ARCH_PARAMS["dag"],
198184
),
199185
"init_mat": INIT_MAT,
200186
"num_factors": model_args.get("num_factors", 0),
201187
"factor_mat_l1": 0.0,
202188
"num_archetypes": model_args.get("num_archetypes", k),
203189
},
190+
sample_specific_loss_params= {
191+
"l1": 0.0,
192+
"dag": DEFAULT_SS_PARAMS["dag"],
193+
}
204194
)
205195
train_dataloader = model.dataloader(
206196
self.C_train, self.X_train, batch_size=1, num_workers=0

contextualized/easy/wrappers/SKLearnWrapper.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
335335
else:
336336
print("X_val not provided, not using the provided C_val.")
337337
if "val_split" in kwargs:
338-
if 0 < kwargs["val_split"] < 1:
338+
if 0 <= kwargs["val_split"] < 1:
339339
val_split = kwargs["val_split"]
340340
else:
341341
print(
@@ -346,15 +346,23 @@ def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
346346
else:
347347
val_split = self.default_val_split
348348
if Y is None:
349-
C_train, C_val, X_train, X_val = train_test_split(
350-
C, X, test_size=val_split, shuffle=True
351-
)
349+
if val_split > 0:
350+
C_train, C_val, X_train, X_val = train_test_split(
351+
C, X, test_size=val_split, shuffle=True
352+
)
353+
else:
354+
C_train, X_train = C, X
355+
C_val, X_val = C, X
352356
train_data = [C_train, X_train]
353357
val_data = [C_val, X_val]
354358
else:
355-
C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
356-
C, X, Y, test_size=val_split, shuffle=True
357-
)
359+
if val_split > 0:
360+
C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
361+
C, X, Y, test_size=val_split, shuffle=True
362+
)
363+
else:
364+
C_train, X_train, Y_train = C, X, Y
365+
C_val, X_val, Y_val = C, X, Y
358366
train_data = [C_train, X_train, Y_train]
359367
val_data = [C_val, X_val, Y_val]
360368
return train_data, val_data

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ keywords = [
2626
]
2727
dependencies = [
2828
'lightning>=2.0.0',
29-
'torch>=2.0.0,<2.2.0',
29+
'torch>=2.0.0',
3030
'torchvision>=0.8.0',
3131
'numpy>=1.19.0',
3232
'pandas>=2.0.0',

0 commit comments

Comments (0)