Skip to content

Commit 55a3001

Browse files
committed
Added tests for FRN on N-D tensors, fixes tensorflow#1441
* Tests for 3D-7D tensors * Error raised when tensors aren't at least 3D
1 parent 7538a77 commit 55a3001

File tree

2 files changed

+56
-80
lines changed

2 files changed

+56
-80
lines changed

tensorflow_addons/layers/normalizations.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class FilterResponseNormalization(tf.keras.layers.Layer):
354354
(tuple of integers, does not include the samples axis)
355355
when using this layer as the first layer in a model. This layer supports
356356
arbitrary tensors with the following assumptions:
357-
- Expected input tensor to be at least 2D.
357+
- Expected input tensor to be at least 3D.
358358
- 0th index in tensor shape is expected to be the batch dimension.
359359
360360
Output shape
@@ -384,7 +384,7 @@ def __init__(
384384
):
385385
super().__init__(name=name, **kwargs)
386386
self.channel_idx = channel_idx
387-
self.epsilon = tf.math.abs(tf.cast(epsilon, dtype=self.dtype))
387+
self.epsilon = epsilon
388388
self.beta_initializer = tf.keras.initializers.get(beta_initializer)
389389
self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
390390
self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
@@ -421,7 +421,7 @@ def build(self, input_shape):
421421
super().build(input_shape)
422422

423423
def call(self, inputs):
424-
epsilon = self.epsilon
424+
epsilon = tf.math.abs(tf.cast(self.epsilon, dtype=self.dtype))
425425
if self.use_eps_learned:
426426
epsilon += tf.math.abs(self.eps_learned)
427427
nu2 = tf.reduce_mean(tf.square(inputs), axis=self.axis, keepdims=True)
@@ -484,11 +484,6 @@ def _check_axis(self, axis):
484484
self.axis = axis
485485

486486
elif isinstance(axis, int):
487-
if abs(axis) != 1 or abs(self.channel_idx) != 1:
488-
raise ValueError(
489-
"Expected index for 2D is -1/1 but got {}".format(axis)
490-
)
491-
492487
self.axis = [axis]
493488

494489
else:
@@ -500,8 +495,12 @@ def _check_axis(self, axis):
500495
raise ValueError("Duplicate axis: %s" % self.axis)
501496

502497
def _check_if_input_shape_is_none(self, input_shape):
503-
dim1, dim2 = input_shape[self.axis[0]], input_shape[self.axis[1]]
504-
if dim1 is None or dim2 is None:
498+
dims = [input_shape[i] for i in self.axis]
499+
500+
if len(input_shape) < 3:
501+
raise ValueError("Expected input tensor to be at least 3D.")
502+
503+
if None in dims:
505504
raise ValueError(
506505
"""Axis {} of input tensor should have a defined dimension but
507506
the layer received an input with shape {}.""".format(

tensorflow_addons/layers/tests/normalizations_test.py

Lines changed: 47 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -368,25 +368,22 @@ def set_random_seed():
368368

369369
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
370370
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
371-
def test_with_beta(dtype):
371+
@pytest.mark.parametrize("r_dim", [3, 4, 5, 6, 7])
372+
def test_with_beta(dtype, r_dim):
372373
set_random_seed()
373374

374-
shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
375+
shape = np.random.choice(range(1, 30), int(r_dim), replace=True)
376+
# shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
375377
inputs = np.random.random_sample(shape).astype(dtype)
378+
print(shape)
376379

377-
if len(shape) == 2:
378-
axis = channel_idx = 1
379-
380-
else:
381-
axis = list(
382-
np.random.choice(
383-
range(1, len(shape) - 1),
384-
np.random.randint(2, len(shape) - 1),
385-
replace=False,
386-
)
380+
axis = list(
381+
np.random.choice(
382+
range(1, len(shape)), np.random.randint(1, len(shape) - 1), replace=False
387383
)
388-
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
389-
channel_idx = int(np.random.choice(channel_idx, 1))
384+
)
385+
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
386+
channel_idx = int(np.random.choice(channel_idx, 1))
390387

391388
frn = FilterResponseNormalization(
392389
beta_initializer="ones",
@@ -406,25 +403,20 @@ def test_with_beta(dtype):
406403

407404
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
408405
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
409-
def test_with_gamma(dtype):
406+
@pytest.mark.parametrize("r_dim", [3, 4, 5, 6, 7])
407+
def test_with_gamma(dtype, r_dim):
410408
set_random_seed()
411409

412-
shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
410+
shape = np.random.choice(range(1, 30), r_dim, replace=True)
413411
inputs = np.random.random_sample(shape).astype(dtype)
414412

415-
if len(shape) == 2:
416-
axis = channel_idx = 1
417-
418-
else:
419-
axis = list(
420-
np.random.choice(
421-
range(1, len(shape) - 1),
422-
np.random.randint(2, len(shape) - 1),
423-
replace=False,
424-
)
413+
axis = list(
414+
np.random.choice(
415+
range(1, len(shape)), np.random.randint(1, len(shape) - 1), replace=False
425416
)
426-
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
427-
channel_idx = int(np.random.choice(channel_idx, 1))
417+
)
418+
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
419+
channel_idx = int(np.random.choice(channel_idx, 1))
428420

429421
frn = FilterResponseNormalization(
430422
beta_initializer="zeros",
@@ -444,25 +436,20 @@ def test_with_gamma(dtype):
444436

445437
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
446438
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
447-
def test_with_epsilon(dtype):
439+
@pytest.mark.parametrize("r_dim", [3, 4, 5, 6, 7])
440+
def test_with_epsilon(dtype, r_dim):
448441
set_random_seed()
449442

450-
shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
443+
shape = np.random.choice(range(1, 30), r_dim, replace=True)
451444
inputs = np.random.random_sample(shape).astype(dtype)
452445

453-
if len(shape) == 2:
454-
axis = channel_idx = 1
455-
456-
else:
457-
axis = list(
458-
np.random.choice(
459-
range(1, len(shape) - 1),
460-
np.random.randint(2, len(shape) - 1),
461-
replace=False,
462-
)
446+
axis = list(
447+
np.random.choice(
448+
range(1, len(shape)), np.random.randint(1, len(shape) - 1), replace=False
463449
)
464-
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
465-
channel_idx = int(np.random.choice(channel_idx, 1))
450+
)
451+
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
452+
channel_idx = int(np.random.choice(channel_idx, 1))
466453

467454
frn = FilterResponseNormalization(
468455
beta_initializer=tf.keras.initializers.Constant(0.5),
@@ -489,27 +476,22 @@ def test_with_epsilon(dtype):
489476

490477
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
491478
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
492-
def test_keras_model(dtype):
479+
@pytest.mark.parametrize("r_dim", [3, 4, 5, 6, 7])
480+
def test_keras_model(dtype, r_dim):
493481
set_random_seed()
494482

495-
shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
483+
shape = np.random.choice(range(1, 30), r_dim, replace=True)
496484
random_inputs = np.random.random_sample(shape).astype(dtype)
497485
random_labels = np.random.randint(2, size=(shape[0],)).astype(dtype)
498486
input_layer = tf.keras.layers.Input(shape=tuple(shape[1:]))
499487

500-
if len(shape) == 2:
501-
axis = channel_idx = 1
502-
503-
else:
504-
axis = list(
505-
np.random.choice(
506-
range(1, len(shape) - 1),
507-
np.random.randint(2, len(shape) - 1),
508-
replace=False,
509-
)
488+
axis = list(
489+
np.random.choice(
490+
range(1, len(shape)), np.random.randint(1, len(shape) - 1), replace=False
510491
)
511-
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
512-
channel_idx = int(np.random.choice(channel_idx, 1))
492+
)
493+
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
494+
channel_idx = int(np.random.choice(channel_idx, 1))
513495

514496
frn = FilterResponseNormalization(
515497
beta_initializer="ones",
@@ -539,27 +521,22 @@ def test_serialization(dtype):
539521

540522
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
541523
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
542-
def test_eps_grads(dtype):
524+
@pytest.mark.parametrize("r_dim", [3, 4, 5, 6, 7])
525+
def test_eps_grads(dtype, r_dim):
543526
set_random_seed()
544527

545-
shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
528+
shape = np.random.choice(range(1, 30), r_dim, replace=True)
546529
random_inputs = np.random.random_sample(shape).astype(dtype)
547530
random_labels = np.random.randint(2, size=(shape[0],)).astype(dtype)
548531
input_layer = tf.keras.layers.Input(shape=tuple(shape[1:]))
549532

550-
if len(shape) == 2:
551-
axis = channel_idx = 1
552-
553-
else:
554-
axis = list(
555-
np.random.choice(
556-
range(1, len(shape) - 1),
557-
np.random.randint(2, len(shape) - 1),
558-
replace=False,
559-
)
533+
axis = list(
534+
np.random.choice(
535+
range(1, len(shape)), np.random.randint(1, len(shape) - 1), replace=False
560536
)
561-
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
562-
channel_idx = int(np.random.choice(channel_idx, 1))
537+
)
538+
channel_idx = list(set(range(len(shape))) - set(axis) - set([0]))
539+
channel_idx = int(np.random.choice(channel_idx, 1))
563540

564541
frn = FilterResponseNormalization(
565542
beta_initializer="ones",

0 commit comments

Comments
 (0)