@@ -368,25 +368,22 @@ def set_random_seed():
368
368
369
369
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
370
370
@pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
371
- def test_with_beta (dtype ):
371
+ @pytest .mark .parametrize ("r_dim" , [3 , 4 , 5 , 6 , 7 ])
372
+ def test_with_beta (dtype , r_dim ):
372
373
set_random_seed ()
373
374
374
- shape = np .random .choice (range (1 , 30 ), np .random .randint (2 , 7 ), replace = True )
375
+ shape = np .random .choice (range (1 , 30 ), int (r_dim ), replace = True )
376
+ # shape = np.random.choice(range(1, 30), np.random.randint(2, 7), replace=True)
375
377
inputs = np .random .random_sample (shape ).astype (dtype )
378
+ print (shape )
376
379
377
- if len (shape ) == 2 :
378
- axis = channel_idx = 1
379
-
380
- else :
381
- axis = list (
382
- np .random .choice (
383
- range (1 , len (shape ) - 1 ),
384
- np .random .randint (2 , len (shape ) - 1 ),
385
- replace = False ,
386
- )
380
+ axis = list (
381
+ np .random .choice (
382
+ range (1 , len (shape )), np .random .randint (1 , len (shape ) - 1 ), replace = False
387
383
)
388
- channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
389
- channel_idx = int (np .random .choice (channel_idx , 1 ))
384
+ )
385
+ channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
386
+ channel_idx = int (np .random .choice (channel_idx , 1 ))
390
387
391
388
frn = FilterResponseNormalization (
392
389
beta_initializer = "ones" ,
@@ -406,25 +403,20 @@ def test_with_beta(dtype):
406
403
407
404
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
408
405
@pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
409
- def test_with_gamma (dtype ):
406
+ @pytest .mark .parametrize ("r_dim" , [3 , 4 , 5 , 6 , 7 ])
407
+ def test_with_gamma (dtype , r_dim ):
410
408
set_random_seed ()
411
409
412
- shape = np .random .choice (range (1 , 30 ), np . random . randint ( 2 , 7 ) , replace = True )
410
+ shape = np .random .choice (range (1 , 30 ), r_dim , replace = True )
413
411
inputs = np .random .random_sample (shape ).astype (dtype )
414
412
415
- if len (shape ) == 2 :
416
- axis = channel_idx = 1
417
-
418
- else :
419
- axis = list (
420
- np .random .choice (
421
- range (1 , len (shape ) - 1 ),
422
- np .random .randint (2 , len (shape ) - 1 ),
423
- replace = False ,
424
- )
413
+ axis = list (
414
+ np .random .choice (
415
+ range (1 , len (shape )), np .random .randint (1 , len (shape ) - 1 ), replace = False
425
416
)
426
- channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
427
- channel_idx = int (np .random .choice (channel_idx , 1 ))
417
+ )
418
+ channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
419
+ channel_idx = int (np .random .choice (channel_idx , 1 ))
428
420
429
421
frn = FilterResponseNormalization (
430
422
beta_initializer = "zeros" ,
@@ -444,25 +436,20 @@ def test_with_gamma(dtype):
444
436
445
437
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
446
438
@pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
447
- def test_with_epsilon (dtype ):
439
+ @pytest .mark .parametrize ("r_dim" , [3 , 4 , 5 , 6 , 7 ])
440
+ def test_with_epsilon (dtype , r_dim ):
448
441
set_random_seed ()
449
442
450
- shape = np .random .choice (range (1 , 30 ), np . random . randint ( 2 , 7 ) , replace = True )
443
+ shape = np .random .choice (range (1 , 30 ), r_dim , replace = True )
451
444
inputs = np .random .random_sample (shape ).astype (dtype )
452
445
453
- if len (shape ) == 2 :
454
- axis = channel_idx = 1
455
-
456
- else :
457
- axis = list (
458
- np .random .choice (
459
- range (1 , len (shape ) - 1 ),
460
- np .random .randint (2 , len (shape ) - 1 ),
461
- replace = False ,
462
- )
446
+ axis = list (
447
+ np .random .choice (
448
+ range (1 , len (shape )), np .random .randint (1 , len (shape ) - 1 ), replace = False
463
449
)
464
- channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
465
- channel_idx = int (np .random .choice (channel_idx , 1 ))
450
+ )
451
+ channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
452
+ channel_idx = int (np .random .choice (channel_idx , 1 ))
466
453
467
454
frn = FilterResponseNormalization (
468
455
beta_initializer = tf .keras .initializers .Constant (0.5 ),
@@ -489,27 +476,22 @@ def test_with_epsilon(dtype):
489
476
490
477
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
491
478
@pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
492
- def test_keras_model (dtype ):
479
+ @pytest .mark .parametrize ("r_dim" , [3 , 4 , 5 , 6 , 7 ])
480
+ def test_keras_model (dtype , r_dim ):
493
481
set_random_seed ()
494
482
495
- shape = np .random .choice (range (1 , 30 ), np . random . randint ( 2 , 7 ) , replace = True )
483
+ shape = np .random .choice (range (1 , 30 ), r_dim , replace = True )
496
484
random_inputs = np .random .random_sample (shape ).astype (dtype )
497
485
random_labels = np .random .randint (2 , size = (shape [0 ],)).astype (dtype )
498
486
input_layer = tf .keras .layers .Input (shape = tuple (shape [1 :]))
499
487
500
- if len (shape ) == 2 :
501
- axis = channel_idx = 1
502
-
503
- else :
504
- axis = list (
505
- np .random .choice (
506
- range (1 , len (shape ) - 1 ),
507
- np .random .randint (2 , len (shape ) - 1 ),
508
- replace = False ,
509
- )
488
+ axis = list (
489
+ np .random .choice (
490
+ range (1 , len (shape )), np .random .randint (1 , len (shape ) - 1 ), replace = False
510
491
)
511
- channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
512
- channel_idx = int (np .random .choice (channel_idx , 1 ))
492
+ )
493
+ channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
494
+ channel_idx = int (np .random .choice (channel_idx , 1 ))
513
495
514
496
frn = FilterResponseNormalization (
515
497
beta_initializer = "ones" ,
@@ -539,27 +521,22 @@ def test_serialization(dtype):
539
521
540
522
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
541
523
@pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
542
- def test_eps_grads (dtype ):
524
+ @pytest .mark .parametrize ("r_dim" , [3 , 4 , 5 , 6 , 7 ])
525
+ def test_eps_grads (dtype , r_dim ):
543
526
set_random_seed ()
544
527
545
- shape = np .random .choice (range (1 , 30 ), np . random . randint ( 2 , 7 ) , replace = True )
528
+ shape = np .random .choice (range (1 , 30 ), r_dim , replace = True )
546
529
random_inputs = np .random .random_sample (shape ).astype (dtype )
547
530
random_labels = np .random .randint (2 , size = (shape [0 ],)).astype (dtype )
548
531
input_layer = tf .keras .layers .Input (shape = tuple (shape [1 :]))
549
532
550
- if len (shape ) == 2 :
551
- axis = channel_idx = 1
552
-
553
- else :
554
- axis = list (
555
- np .random .choice (
556
- range (1 , len (shape ) - 1 ),
557
- np .random .randint (2 , len (shape ) - 1 ),
558
- replace = False ,
559
- )
533
+ axis = list (
534
+ np .random .choice (
535
+ range (1 , len (shape )), np .random .randint (1 , len (shape ) - 1 ), replace = False
560
536
)
561
- channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
562
- channel_idx = int (np .random .choice (channel_idx , 1 ))
537
+ )
538
+ channel_idx = list (set (range (len (shape ))) - set (axis ) - set ([0 ]))
539
+ channel_idx = int (np .random .choice (channel_idx , 1 ))
563
540
564
541
frn = FilterResponseNormalization (
565
542
beta_initializer = "ones" ,
0 commit comments