Skip to content

Commit 8d64854

Browse files
committed
fix tests
1 parent 2be151e commit 8d64854

File tree

2 files changed

+252
-249
lines changed

2 files changed

+252
-249
lines changed
Lines changed: 168 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,165 +1,168 @@
1-
import torch
2-
3-
4-
""" Functions """
5-
6-
7-
def _to_dtype_(tensor, dtype=None):
8-
if dtype is None:
9-
return tensor
10-
11-
default_dtype = str(torch.get_default_dtype()).split(".")[-1] # torch.float32 -> "float32"
12-
if dtype == default_dtype:
13-
return tensor
14-
return tensor.to(dtype=getattr(torch, dtype, torch.get_default_dtype()))
15-
16-
17-
def constant(value=0):
18-
return Constant(value=value)
19-
20-
21-
def glorot_normal():
22-
return GlorotNormal()
23-
24-
25-
def glorot_uniform():
26-
return GlorotUniform()
27-
28-
29-
def he_normal():
30-
return HeNormal()
31-
32-
33-
def he_uniform():
34-
return HeUniform()
35-
36-
37-
def ones():
38-
return Ones()
39-
40-
41-
def random_normal(mean=0.0, stddev=1e-6, seed=None):
42-
return RandomNormal(mean=mean, stddev=stddev, seed=seed)
43-
44-
45-
def random_uniform(minval=-0.05, maxval=0.05, seed=None):
46-
return RandomUniform(minval=minval, maxval=maxval, seed=seed)
47-
48-
49-
def truncated_normal(mean=0.0, stddev=1e-6, seed=None):
50-
return TruncatedNormal(mean=mean, stddev=stddev, seed=seed)
51-
52-
53-
def zeros():
54-
return Zeros()
55-
56-
57-
""" Classes """
58-
59-
60-
class Initializer:
61-
def __init__(self, seed=None):
62-
self.seed = seed
63-
64-
@classmethod
65-
def from_config(cls, config):
66-
config.pop("dtype", None)
67-
return cls(**config)
68-
69-
70-
class Constant(Initializer):
71-
def __init__(self, value=0):
72-
self.value = value
73-
super().__init__(seed=None)
74-
75-
def __call__(self, shape, dtype=None, **kwargs):
76-
if hasattr(self.value, "shape") and tuple(self.value.shape) == tuple(shape):
77-
return _to_dtype_(self.value, dtype)
78-
else:
79-
return _to_dtype_(torch.nn.init.constant_(torch.empty(shape), val=self.value), dtype)
80-
81-
def get_config(self):
82-
return {"value": self.value}
83-
84-
85-
class GlorotNormal(Initializer):
86-
def __call__(self, shape, dtype=None, **kwargs):
87-
return _to_dtype_(torch.nn.init.xavier_normal_(torch.empty(shape)), dtype)
88-
89-
90-
class GlorotUniform(Initializer):
91-
def __call__(self, shape, dtype=None, **kwargs):
92-
return _to_dtype_(torch.nn.init.xavier_uniform_(torch.empty(shape)), dtype)
93-
94-
95-
class HeNormal(Initializer):
96-
def __call__(self, shape, dtype=None, **kwargs):
97-
return _to_dtype_(torch.nn.init.kaiming_normal_(torch.empty(shape)), dtype)
98-
99-
100-
class HeUniform(Initializer):
101-
def __call__(self, shape, dtype=None, **kwargs):
102-
return _to_dtype_(torch.nn.init.kaiming_uniform_(torch.empty(shape)), dtype)
103-
104-
105-
class Ones(Initializer):
106-
def __call__(self, shape, dtype=None, **kwargs):
107-
return _to_dtype_(torch.nn.init.ones_(torch.empty(shape)), dtype)
108-
109-
110-
class RandomNormal(Initializer):
111-
def __init__(self, mean=0.0, stddev=0.05, seed=None):
112-
self.mean, self.stddev = mean, stddev
113-
super().__init__(seed=seed)
114-
115-
def __call__(self, shape, dtype=None, **kwargs):
116-
return _to_dtype_(torch.nn.init.normal_(torch.empty(shape), mean=self.mean, std=self.stddev), dtype)
117-
118-
def get_config(self):
119-
return {"mean": self.mean, "stddev": self.stddev}
120-
121-
122-
class RandomUniform(Initializer):
123-
def __init__(self, minval=-0.05, maxval=0.05, seed=None):
124-
self.minval, self.maxval = minval, maxval
125-
super().__init__(seed=seed)
126-
127-
def __call__(self, shape, dtype=None, **kwargs):
128-
return _to_dtype_(torch.nn.init.uniform_(torch.empty(shape), a=self.minval, b=self.maxval), dtype)
129-
130-
def get_config(self):
131-
return {"minval": self.minval, "maxval": self.maxval}
132-
133-
134-
class TruncatedNormal(Initializer):
135-
def __init__(self, mean=0.0, stddev=0.05, seed=None):
136-
self.mean, self.stddev = mean, stddev
137-
super().__init__(seed=seed)
138-
139-
def __call__(self, shape, dtype=None, **kwargs):
140-
return _to_dtype_(torch.nn.init.trunc_normal_(torch.empty(shape), mean=self.mean, std=self.stddev), dtype)
141-
142-
def get_config(self):
143-
return {"mean": self.mean, "stddev": self.stddev}
144-
145-
146-
class VarianceScaling(Initializer):
147-
def __init__(self, scale=1.0, mode="fan_in", distribution="truncated_normal", seed=None):
148-
# scale=2.0, mode="fan_in", distribution="uniform", seed=seed # HeUniform
149-
# scale=2.0, mode="fan_in", distribution="truncated_normal", seed=seed # HeNormal
150-
# scale=1.0, mode="fan_in", distribution="uniform", seed=seed # LecunUniform
151-
# scale=1.0, mode="fan_in", distribution="truncated_normal", seed=seed # LecunNormal
152-
# scale=1.0, mode="fan_avg", distribution="uniform", seed=seed # GlorotUniform
153-
# scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed # GlorotNormal
154-
self.scale, self.mode, self.distribution, self.seed = scale, mode, distribution, seed
155-
156-
def __call__(self, shape, dtype=None, **kwargs):
157-
return _to_dtype_(torch.zeros(shape), dtype) # [TODO]
158-
159-
def get_config(self):
160-
return {"scale": self.scale, "mode": self.mode, "distribution": self.distribution}
161-
162-
163-
class Zeros(Initializer):
164-
def __call__(self, shape, dtype=None, **kwargs):
165-
return _to_dtype_(torch.nn.init.zeros_(torch.empty(shape)), dtype)
1+
import torch
2+
3+
4+
""" Functions """
5+
6+
7+
def _to_dtype_(tensor, dtype=None):
8+
if dtype is None:
9+
return tensor
10+
11+
default_dtype = str(torch.get_default_dtype()).split(".")[-1] # torch.float32 -> "float32"
12+
if dtype == default_dtype:
13+
return tensor
14+
return tensor.to(dtype=getattr(torch, dtype, torch.get_default_dtype()))
15+
16+
17+
def constant(value=0):
18+
return Constant(value=value)
19+
20+
21+
def glorot_normal():
22+
return GlorotNormal()
23+
24+
25+
def glorot_uniform():
26+
return GlorotUniform()
27+
28+
29+
def he_normal():
30+
return HeNormal()
31+
32+
33+
def he_uniform():
34+
return HeUniform()
35+
36+
37+
def ones():
38+
return Ones()
39+
40+
41+
def random_normal(mean=0.0, stddev=1e-6, seed=None):
42+
return RandomNormal(mean=mean, stddev=stddev, seed=seed)
43+
44+
45+
def random_uniform(minval=-0.05, maxval=0.05, seed=None):
46+
return RandomUniform(minval=minval, maxval=maxval, seed=seed)
47+
48+
49+
def truncated_normal(mean=0.0, stddev=1e-6, seed=None):
50+
return TruncatedNormal(mean=mean, stddev=stddev, seed=seed)
51+
52+
53+
def zeros():
54+
return Zeros()
55+
56+
57+
""" Classes """
58+
59+
60+
class Initializer:
61+
def __init__(self, seed=None):
62+
self.seed = seed
63+
64+
@classmethod
65+
def from_config(cls, config):
66+
config.pop("dtype", None)
67+
return cls(**config)
68+
69+
def get_config(self):
70+
return {"seed": self.seed}
71+
72+
73+
class Constant(Initializer):
74+
def __init__(self, value=0):
75+
self.value = value
76+
super().__init__(seed=None)
77+
78+
def __call__(self, shape, dtype=None, **kwargs):
79+
if hasattr(self.value, "shape") and tuple(self.value.shape) == tuple(shape):
80+
return _to_dtype_(self.value, dtype)
81+
else:
82+
return _to_dtype_(torch.nn.init.constant_(torch.empty(shape), val=self.value), dtype)
83+
84+
def get_config(self):
85+
return {"value": self.value}
86+
87+
88+
class GlorotNormal(Initializer):
89+
def __call__(self, shape, dtype=None, **kwargs):
90+
return _to_dtype_(torch.nn.init.xavier_normal_(torch.empty(shape)), dtype)
91+
92+
93+
class GlorotUniform(Initializer):
94+
def __call__(self, shape, dtype=None, **kwargs):
95+
return _to_dtype_(torch.nn.init.xavier_uniform_(torch.empty(shape)), dtype)
96+
97+
98+
class HeNormal(Initializer):
99+
def __call__(self, shape, dtype=None, **kwargs):
100+
return _to_dtype_(torch.nn.init.kaiming_normal_(torch.empty(shape)), dtype)
101+
102+
103+
class HeUniform(Initializer):
104+
def __call__(self, shape, dtype=None, **kwargs):
105+
return _to_dtype_(torch.nn.init.kaiming_uniform_(torch.empty(shape)), dtype)
106+
107+
108+
class Ones(Initializer):
109+
def __call__(self, shape, dtype=None, **kwargs):
110+
return _to_dtype_(torch.nn.init.ones_(torch.empty(shape)), dtype)
111+
112+
113+
class RandomNormal(Initializer):
114+
def __init__(self, mean=0.0, stddev=0.05, seed=None):
115+
self.mean, self.stddev = mean, stddev
116+
super().__init__(seed=seed)
117+
118+
def __call__(self, shape, dtype=None, **kwargs):
119+
return _to_dtype_(torch.nn.init.normal_(torch.empty(shape), mean=self.mean, std=self.stddev), dtype)
120+
121+
def get_config(self):
122+
return {"mean": self.mean, "stddev": self.stddev}
123+
124+
125+
class RandomUniform(Initializer):
126+
def __init__(self, minval=-0.05, maxval=0.05, seed=None):
127+
self.minval, self.maxval = minval, maxval
128+
super().__init__(seed=seed)
129+
130+
def __call__(self, shape, dtype=None, **kwargs):
131+
return _to_dtype_(torch.nn.init.uniform_(torch.empty(shape), a=self.minval, b=self.maxval), dtype)
132+
133+
def get_config(self):
134+
return {"minval": self.minval, "maxval": self.maxval}
135+
136+
137+
class TruncatedNormal(Initializer):
138+
def __init__(self, mean=0.0, stddev=0.05, seed=None):
139+
self.mean, self.stddev = mean, stddev
140+
super().__init__(seed=seed)
141+
142+
def __call__(self, shape, dtype=None, **kwargs):
143+
return _to_dtype_(torch.nn.init.trunc_normal_(torch.empty(shape), mean=self.mean, std=self.stddev), dtype)
144+
145+
def get_config(self):
146+
return {"mean": self.mean, "stddev": self.stddev}
147+
148+
149+
class VarianceScaling(Initializer):
150+
def __init__(self, scale=1.0, mode="fan_in", distribution="truncated_normal", seed=None):
151+
# scale=2.0, mode="fan_in", distribution="uniform", seed=seed # HeUniform
152+
# scale=2.0, mode="fan_in", distribution="truncated_normal", seed=seed # HeNormal
153+
# scale=1.0, mode="fan_in", distribution="uniform", seed=seed # LecunUniform
154+
# scale=1.0, mode="fan_in", distribution="truncated_normal", seed=seed # LecunNormal
155+
# scale=1.0, mode="fan_avg", distribution="uniform", seed=seed # GlorotUniform
156+
# scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed # GlorotNormal
157+
self.scale, self.mode, self.distribution, self.seed = scale, mode, distribution, seed
158+
159+
def __call__(self, shape, dtype=None, **kwargs):
160+
return _to_dtype_(torch.zeros(shape), dtype) # [TODO]
161+
162+
def get_config(self):
163+
return {"scale": self.scale, "mode": self.mode, "distribution": self.distribution}
164+
165+
166+
class Zeros(Initializer):
167+
def __call__(self, shape, dtype=None, **kwargs):
168+
return _to_dtype_(torch.nn.init.zeros_(torch.empty(shape)), dtype)

0 commit comments

Comments
 (0)