From eb3a2b71e2f03234641e2faf9c94873961a99fa4 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sun, 19 Jan 2020 23:04:07 +0000 Subject: [PATCH 1/5] Replaced hardshrink by a pure python version. --- tensorflow_addons/activations/hardshrink.py | 37 +++-- .../custom_ops/activations/BUILD | 4 - .../activations/cc/kernels/hardshrink_op.cc | 81 ---------- .../activations/cc/kernels/hardshrink_op.h | 144 ------------------ .../cc/kernels/hardshrink_op_gpu.cu.cc | 38 ----- .../activations/cc/ops/hardshrink_op.cc | 41 ----- 6 files changed, 27 insertions(+), 318 deletions(-) delete mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc delete mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h delete mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc delete mode 100644 tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py index fabc217c27..32f8c0b0e3 100644 --- a/tensorflow_addons/activations/hardshrink.py +++ b/tensorflow_addons/activations/hardshrink.py @@ -14,9 +14,33 @@ # ============================================================================== import tensorflow as tf -from tensorflow_addons.utils.resource_loader import LazySO -_activation_so = LazySO("custom_ops/activations/_activation_ops.so") + +def _hardshrink(x, lower, upper): + mask_lower = x < lower + mask_upper = upper < x + mask = tf.logical_or(mask_lower, mask_upper) + mask = tf.cast(mask, tf.float32) + return x * mask + + +def compile_with_xla(func, dtype): + compiled = tf.function( + func, + input_signature=(tf.TensorSpec(shape=None, dtype=dtype), + tf.TensorSpec(shape=tuple(), dtype=dtype), + tf.TensorSpec(shape=tuple(), dtype=dtype)), + autograph=False, + experimental_compile=True + ) + return compiled + + +supported_dtypes = [tf.float16, tf.float32, tf.float64] + +function_dispatch = {} +for dtype in supported_dtypes: + function_dispatch[dtype] = compile_with_xla(_hardshrink, dtype) @tf.keras.utils.register_keras_serializable(package='Addons') @@ -35,11 +59,4 @@ def hardshrink(x, lower=-0.5, upper=0.5): A `Tensor`. Has the same type as `x`. """ x = tf.convert_to_tensor(x) - return _activation_so.ops.addons_hardshrink(x, lower, upper) - - -@tf.RegisterGradient("Addons>Hardshrink") -def _hardshrink_grad(op, grad): - return _activation_so.ops.addons_hardshrink_grad(grad, op.inputs[0], - op.get_attr("lower"), - op.get_attr("upper")) + return function_dispatch[x.dtype](x, lower, upper) diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD index 64150404b0..effee2cf93 100644 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -9,8 +9,6 @@ custom_op_library( srcs = [ "cc/kernels/gelu_op.cc", "cc/kernels/gelu_op.h", - "cc/kernels/hardshrink_op.cc", - "cc/kernels/hardshrink_op.h", "cc/kernels/lisht_op.cc", "cc/kernels/lisht_op.h", "cc/kernels/mish_op.cc", @@ -29,8 +27,6 @@ custom_op_library( cuda_srcs = [ "cc/kernels/gelu_op.h", "cc/kernels/gelu_op_gpu.cu.cc", - "cc/kernels/hardshrink_op.h", - "cc/kernels/hardshrink_op_gpu.cu.cc", "cc/kernels/lisht_op.h", "cc/kernels/lisht_op_gpu.cu.cc", "cc/kernels/mish_op.h", diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc deleted file mode 100644 index 8563d81f64..0000000000 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -namespace tensorflow { -namespace addons { - -using CPUDevice = Eigen::ThreadPoolDevice; - -#define REGISTER_HARDSHRINK_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Addons>Hardshrink").Device(DEVICE_CPU).TypeConstraint("T"), \ - HardshrinkOp); \ - REGISTER_KERNEL_BUILDER(Name("Addons>HardshrinkGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - HardshrinkGradOp); - -// Hardshrink only makes sense with floating points. -TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_KERNELS); -#undef REGISTER_HARDSHRINK_KERNELS - -#if GOOGLE_CUDA - -using GPUDevice = Eigen::GpuDevice; - -// Forward declarations of the functor specializations for GPU. -namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void Hardshrink::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, T lower, \ - T upper, typename TTypes::Tensor activations); \ - extern template struct Hardshrink; \ - \ - template <> \ - void HardshrinkGrad::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, T lower, T upper, \ - typename TTypes::Tensor backprops); \ - extern template struct HardshrinkGrad; - -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); -#undef DECLARE_GPU_SPEC -} // namespace functor - -// Registration of the GPU implementations. -#define REGISTER_HARDSHRINK_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Addons>Hardshrink").Device(DEVICE_GPU).TypeConstraint("T"), \ - HardshrinkOp); \ - REGISTER_KERNEL_BUILDER(Name("Addons>HardshrinkGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - HardshrinkGradOp); - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS); -#undef REGISTER_HARDSHRINK_GPU_KERNELS - -#endif // GOOGLE_CUDA - -} // namespace addons -} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h deleted file mode 100644 index 92313dc0eb..0000000000 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_ -#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_ - -#define EIGEN_USE_THREADS - -#include "tensorflow/core/framework/numeric_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -namespace tensorflow { -namespace addons { - -namespace functor { - -// Functor used by HardshrinkOp to do the computations. -template -struct Hardshrink { - // Computes Hardshrink activation. - // - // features: any shape. - // lower: the lower bound for setting values to zeros. - // upper: the upper bound for setting values to zeros. - // activations: same shape as "features". - void operator()(const Device& d, typename TTypes::ConstTensor features, - T lower, T upper, typename TTypes::Tensor activations) { - activations.device(d) = - (features < lower || features > upper) - .select(features, features.constant(static_cast(0))); - } -}; - -// Functor used by HardshrinkGradOp to do the computations. -template -struct HardshrinkGrad { - // Computes HardshrinkGrad backprops. - // - // gradients: gradients backpropagated to the Hardshink op. - // features: inputs that were passed to the Hardshrink op. - // lower: the lower bound for setting values to zeros. - // upper: the upper bound for setting values to zeros. - // backprops: gradients to backpropagate to the Hardshrink inputs. - void operator()(const Device& d, typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, T lower, T upper, - typename TTypes::Tensor backprops) { - backprops.device(d) = - (features < lower || features > upper) - .select(gradients, features.constant(static_cast(0))); - } -}; - -} // namespace functor - -template -class HardshrinkOp : public UnaryElementWiseOp> { - public: - explicit HardshrinkOp(OpKernelConstruction* context) - : UnaryElementWiseOp>::UnaryElementWiseOp( - context) { - float lower, upper; - OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); - OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); - lower_ = static_cast(lower); - upper_ = static_cast(upper); - - OP_REQUIRES( - context, lower_ <= upper_, - errors::InvalidArgument("lower must be less than or equal to upper.")); - } - - void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { - functor::Hardshrink functor; - functor(context->eigen_device(), input.flat(), lower_, upper_, - output->flat()); - } - - private: - T lower_; - T upper_; -}; - -template -class HardshrinkGradOp - : public BinaryElementWiseOp> { - public: - explicit HardshrinkGradOp(OpKernelConstruction* context) - : BinaryElementWiseOp< - T, HardshrinkGradOp>::BinaryElementWiseOp(context) { - float lower, upper; - OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); - OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); - lower_ = static_cast(lower); - upper_ = static_cast(upper); - - OP_REQUIRES( - context, lower_ <= upper_, - errors::InvalidArgument("lower must be less than or equal to upper.")); - } - - void OperateNoTemplate(OpKernelContext* context, const Tensor& g, - const Tensor& a, T lower, T upper, Tensor* output); - - template - void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, - Tensor* output) { - OperateNoTemplate(context, g, a, lower_, upper_, output); - } - - private: - T lower_; - T upper_; -}; - -template -void HardshrinkGradOp::OperateNoTemplate(OpKernelContext* context, - const Tensor& g, - const Tensor& a, T lower, - T upper, Tensor* output) { - functor::HardshrinkGrad functor; - functor(context->eigen_device(), g.flat(), a.flat(), lower, - upper, output->flat()); -} - -} // namespace addons -} // namespace tensorflow - -#undef EIGEN_USE_THREADS - -#endif // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc deleted file mode 100644 index 9b4d2a9e83..0000000000 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h" -#include "tensorflow/core/framework/register_types.h" -#include "third_party/eigen3/Eigen/Core" - -namespace tensorflow { -namespace addons { - -using GPUDevice = Eigen::GpuDevice; - -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Hardshrink; \ - template struct functor::HardshrinkGrad; - -TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); - -} // namespace addons -} // namespace tensorflow - -#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc deleted file mode 100644 index 5eecf7e2e5..0000000000 --- a/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { -namespace addons { - -REGISTER_OP("Addons>Hardshrink") - .Input("features: T") - .Output("activations: T") - .Attr("T: {half, float, double}") - .Attr("lower: float = -0.5") - .Attr("upper: float = 0.5") - .SetShapeFn(shape_inference::UnchangedShape); - -REGISTER_OP("Addons>HardshrinkGrad") - .Input("gradients: T") - .Input("features: T") - .Output("backprops: T") - .Attr("T: {half, float, double}") - .Attr("lower: float = -0.5") - .Attr("upper: float = 0.5") - .SetShapeFn(shape_inference::MergeBothInputsShapeFn); - -} // namespace addons -} // namespace tensorflow From 4b37fb24ba2d82acb4cbfd56138bab6e0bf4db0b Mon Sep 17 00:00:00 2001 From: Gabriel de Marmiesse Date: Mon, 20 Jan 2020 01:23:38 +0100 Subject: [PATCH 2/5] Update hardshrink.py --- tensorflow_addons/activations/hardshrink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py index 32f8c0b0e3..e1d67d8cc2 100644 --- a/tensorflow_addons/activations/hardshrink.py +++ b/tensorflow_addons/activations/hardshrink.py @@ -20,7 +20,7 @@ def _hardshrink(x, lower, upper): mask_lower = x < lower mask_upper = upper < x mask = tf.logical_or(mask_lower, mask_upper) - mask = tf.cast(mask, tf.float32) + mask = tf.cast(mask, x.dtype) return x * mask From f6d3b3b9bb431cc1139a028c367db5a421605e6a Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Mon, 20 Jan 2020 15:35:33 +0100 Subject: [PATCH 3/5] Forgot a line. --- tensorflow_addons/custom_ops/activations/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD index effee2cf93..b59469f860 100644 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -18,7 +18,6 @@ custom_op_library( "cc/kernels/tanhshrink_op.cc", "cc/kernels/tanhshrink_op.h", "cc/ops/gelu_op.cc", - "cc/ops/hardshrink_op.cc", "cc/ops/lisht_op.cc", "cc/ops/mish_op.cc", "cc/ops/softshrink_op.cc", From f65e31830ab41359e645811ca930318dc9c3e3be Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Mon, 20 Jan 2020 17:37:30 +0100 Subject: [PATCH 4/5] Added a check. --- tensorflow_addons/activations/hardshrink.py | 4 ++++ tensorflow_addons/activations/hardshrink_test.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py index e1d67d8cc2..50e5838fb2 100644 --- a/tensorflow_addons/activations/hardshrink.py +++ b/tensorflow_addons/activations/hardshrink.py @@ -58,5 +58,9 @@ def hardshrink(x, lower=-0.5, upper=0.5): Returns: A `Tensor`. Has the same type as `x`. """ + if lower > upper: + raise ValueError("The value of lower is {} and should" + " not be higher than the value " + "variable upper, which is {} .".format(lower, upper)) x = tf.convert_to_tensor(x) return function_dispatch[x.dtype](x, lower, upper) diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py index 831d76e5c4..fa6e585e36 100644 --- a/tensorflow_addons/activations/hardshrink_test.py +++ b/tensorflow_addons/activations/hardshrink_test.py @@ -24,8 +24,7 @@ @test_utils.run_all_in_graph_and_eager_modes class HardshrinkTest(tf.test.TestCase, parameterized.TestCase): def test_invalid(self): - with self.assertRaisesOpError( - "lower must be less than or equal to upper."): # pylint: disable=bad-continuation + with self.assertRaises(ValueError): y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) self.evaluate(y) From 0ed7f87a4492ad55847f961a2b7c0ec09da1bcfd Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Wed, 22 Jan 2020 16:40:14 +0100 Subject: [PATCH 5/5] Used a version without decorator. --- tensorflow_addons/activations/hardshrink.py | 33 ++++----------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py index 50e5838fb2..7512d6d62d 100644 --- a/tensorflow_addons/activations/hardshrink.py +++ b/tensorflow_addons/activations/hardshrink.py @@ -16,33 +16,6 @@ import tensorflow as tf -def _hardshrink(x, lower, upper): - mask_lower = x < lower - mask_upper = upper < x - mask = tf.logical_or(mask_lower, mask_upper) - mask = tf.cast(mask, x.dtype) - return x * mask - - -def compile_with_xla(func, dtype): - compiled = tf.function( - func, - input_signature=(tf.TensorSpec(shape=None, dtype=dtype), - tf.TensorSpec(shape=tuple(), dtype=dtype), - tf.TensorSpec(shape=tuple(), dtype=dtype)), - autograph=False, - experimental_compile=True - ) - return compiled - - -supported_dtypes = [tf.float16, tf.float32, tf.float64] - -function_dispatch = {} -for dtype in supported_dtypes: - function_dispatch[dtype] = compile_with_xla(_hardshrink, dtype) - - @tf.keras.utils.register_keras_serializable(package='Addons') def hardshrink(x, lower=-0.5, upper=0.5): """Hard shrink function. @@ -63,4 +36,8 @@ def hardshrink(x, lower=-0.5, upper=0.5): " not be higher than the value " "variable upper, which is {} .".format(lower, upper)) x = tf.convert_to_tensor(x) - return function_dispatch[x.dtype](x, lower, upper) + mask_lower = x < lower + mask_upper = upper < x + mask = tf.logical_or(mask_lower, mask_upper) + mask = tf.cast(mask, x.dtype) + return x * mask