From 8cf4fe936d35fc5a07623b8145e0082212db359e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 1 Jun 2021 11:18:31 -0400 Subject: [PATCH 1/5] Move Ops from CTOR to call method --- .../activations/AbstractActivation.java | 46 ++++++ .../framework/activations/Activation.java | 45 +----- .../tensorflow/framework/activations/ELU.java | 34 ++--- .../framework/activations/Exponential.java | 23 +-- .../framework/activations/HardSigmoid.java | 31 ++-- .../framework/activations/Linear.java | 18 +-- .../framework/activations/ReLU.java | 30 ++-- .../framework/activations/SELU.java | 21 +-- .../framework/activations/Sigmoid.java | 21 +-- .../framework/activations/Softmax.java | 22 +-- .../framework/activations/Softplus.java | 21 +-- .../framework/activations/Softsign.java | 21 +-- .../framework/activations/Swish.java | 11 +- .../framework/activations/Tanh.java | 14 +- .../constraints/AbstractConstraint.java | 89 ++++++++++++ .../framework/constraints/Constraint.java | 88 +----------- .../framework/constraints/MaxNorm.java | 30 ++-- .../framework/constraints/MinMaxNorm.java | 30 ++-- .../framework/constraints/NonNeg.java | 15 +- .../framework/constraints/UnitNorm.java | 31 ++-- .../initializers/BaseInitializer.java | 21 ++- .../framework/initializers/Constant.java | 31 ++-- .../framework/initializers/Glorot.java | 12 +- .../tensorflow/framework/initializers/He.java | 16 +-- .../framework/initializers/Identity.java | 30 ++-- .../framework/initializers/Initializer.java | 7 +- .../framework/initializers/LeCun.java | 15 +- .../framework/initializers/Ones.java | 20 +-- .../framework/initializers/Orthogonal.java | 21 +-- .../framework/initializers/RandomNormal.java | 26 ++-- .../framework/initializers/RandomUniform.java | 31 ++-- .../initializers/TruncatedNormal.java | 23 +-- .../initializers/VarianceScaling.java | 32 ++--- .../framework/initializers/Zeros.java | 17 +-- .../framework/losses/BinaryCrossentropy.java | 79 +++++----- .../losses/CategoricalCrossentropy.java | 135 
++++++++---------- .../framework/losses/CategoricalHinge.java | 40 +++--- .../framework/losses/CosineSimilarity.java | 115 +++++++-------- .../tensorflow/framework/losses/Hinge.java | 48 +++---- .../tensorflow/framework/losses/Huber.java | 61 ++++---- .../framework/losses/KLDivergence.java | 50 ++++--- .../tensorflow/framework/losses/LogCosh.java | 54 ++++--- .../org/tensorflow/framework/losses/Loss.java | 78 +--------- .../framework/losses/MeanAbsoluteError.java | 44 +++--- .../losses/MeanAbsolutePercentageError.java | 45 +++--- .../framework/losses/MeanSquaredError.java | 44 +++--- .../losses/MeanSquaredLogarithmicError.java | 44 +++--- .../tensorflow/framework/losses/Poisson.java | 54 ++++--- .../framework/losses/Reduction.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 73 +++++----- .../framework/losses/SquaredHinge.java | 53 ++++--- .../framework/losses/impl/AbstractLoss.java | 89 ++++++++++++ .../org/tensorflow/framework/metrics/AUC.java | 95 ++++++------ .../framework/metrics/Accuracy.java | 8 +- .../framework/metrics/BinaryAccuracy.java | 8 +- .../metrics/CategoricalAccuracy.java | 19 ++- .../metrics/CategoricalCrossentropy.java | 20 ++- .../framework/metrics/FalseNegatives.java | 42 +++--- .../framework/metrics/FalsePositives.java | 42 +++--- .../tensorflow/framework/metrics/MeanIoU.java | 14 +- .../framework/metrics/MeanRelativeError.java | 11 +- .../framework/metrics/MeanTensor.java | 4 +- .../framework/metrics/Precision.java | 71 +++++---- .../framework/metrics/PrecisionAtRecall.java | 7 +- .../tensorflow/framework/metrics/Recall.java | 26 ++-- .../framework/metrics/RecallAtPrecision.java | 4 +- .../metrics/RootMeanSquaredError.java | 3 +- .../metrics/SensitivityAtSpecificity.java | 20 +-- .../metrics/SparseCategoricalAccuracy.java | 6 +- .../metrics/SpecificityAtSensitivity.java | 20 +-- .../org/tensorflow/framework/metrics/Sum.java | 8 +- .../metrics/TopKCategoricalAccuracy.java | 4 +- .../framework/metrics/TrueNegatives.java | 42 +++--- 
.../framework/metrics/TruePositives.java | 42 +++--- .../impl/ConfusionMatrixConditionCount.java | 26 ++-- .../framework/metrics/impl/LossMetric.java | 2 +- .../metrics/impl/MeanMetricWrapper.java | 8 +- .../framework/metrics/impl/MetricsHelper.java | 116 +++++++-------- .../impl/SensitivitySpecificityBase.java | 6 +- .../framework/metrics/impl/SetsOps.java | 24 ++-- .../framework/metrics/impl/SymbolicShape.java | 45 +++++- .../metrics/impl/WeightsBroadcastOps.java | 34 ++--- .../regularizers/AbstractRegularizer.java | 63 ++++++++ .../tensorflow/framework/regularizers/L1.java | 33 +++-- .../framework/regularizers/L1L2.java | 38 ++--- .../tensorflow/framework/regularizers/L2.java | 33 +++-- .../framework/regularizers/Regularizer.java | 67 +-------- .../regularizers/RegularizerLoss.java | 31 ++-- .../framework/activations/ELUTest.java | 33 +---- .../activations/ExponentialTest.java | 28 +--- .../activations/HardSigmoidTest.java | 28 +--- .../framework/activations/LinearTest.java | 28 +--- .../framework/activations/ReLUTest.java | 58 ++++---- .../framework/activations/SELUTest.java | 28 +--- .../framework/activations/SigmoidTest.java | 27 +--- .../framework/activations/SoftmaxTest.java | 47 ++---- .../framework/activations/SoftplusTest.java | 24 +--- .../framework/activations/SoftsignTest.java | 24 +--- .../framework/activations/SwishTest.java | 28 +--- .../framework/activations/TanhTest.java | 24 +--- .../framework/constraints/MaxNormTest.java | 8 +- .../framework/constraints/MinMaxNormTest.java | 4 +- .../framework/constraints/NonNegTest.java | 8 +- .../framework/constraints/UnitNormTest.java | 8 +- .../framework/initializers/ConstantTest.java | 66 ++++----- .../framework/initializers/GlorotTest.java | 57 ++++---- .../framework/initializers/HeTest.java | 57 ++++---- .../framework/initializers/IdentityTest.java | 34 ++--- .../framework/initializers/LeCunTest.java | 50 +++---- .../framework/initializers/OnesTest.java | 72 +++++----- 
.../initializers/OrthogonalTest.java | 34 ++--- .../initializers/RandomNormalTest.java | 33 ++--- .../initializers/RandomUniformTest.java | 38 ++--- .../initializers/TruncatedNormalTest.java | 33 ++--- .../initializers/VarianceScalingTest.java | 73 +++------- .../framework/initializers/ZerosTest.java | 72 +++++----- .../losses/BinaryCrossentropyTest.java | 54 ++++--- .../losses/CategoricalCrossentropyTest.java | 56 +++++--- .../losses/CategoricalHingeTest.java | 32 +++-- .../losses/CosineSimilarityTest.java | 35 +++-- .../framework/losses/HingeTest.java | 30 ++-- .../framework/losses/HuberTest.java | 30 ++-- .../framework/losses/KLDivergenceTest.java | 25 ++-- .../framework/losses/LogCoshTest.java | 25 ++-- .../losses/MeanAbsoluteErrorTest.java | 45 +++--- .../MeanAbsolutePercentageErrorTest.java | 40 +++--- .../losses/MeanSquaredErrorTest.java | 45 +++--- .../MeanSquaredLogarithmicErrorTest.java | 45 +++--- .../framework/losses/PoissonTest.java | 25 ++-- .../SparseCategoricalCrossentropyTest.java | 50 ++++--- .../framework/losses/SquaredHingeTest.java | 30 ++-- .../optimizers/GradientDescentTest.java | 48 ++++--- .../framework/regularizers/L1L2Test.java | 38 ++--- .../framework/regularizers/L1Test.java | 22 +-- .../framework/regularizers/L2Test.java | 22 +-- .../regularizers/RegularizerLossTest.java | 8 +- 136 files changed, 2278 insertions(+), 2544 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java new file mode 100644 index 00000000000..335b8697273 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Abstract base class for Activations */ +public abstract class AbstractActivation implements Activation { + + /** The TensorFlow Ops */ + protected Ops tf; + + /** Creates the abstract class for an AbstractActivation */ + protected AbstractActivation() {} + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + protected Ops getTF() { + return this.tf; + } + + /** + * Sets the TensorFlow Ops + * + * @param tf the TensorFlow Ops + */ + protected void setTF(Ops tf) { + this.tf = tf; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java index e1482a51a8a..f73c6678ab3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,50 +19,19 @@ import org.tensorflow.types.family.TNumber; /** - * Abstract base class for Activations + * Interface for Activations * - *

Note: The {@link #tf} attribute must be set prior to invoking the call method. See - * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. - * - * @param the data type of the activation + * @param the data type of the input and the result */ -public abstract class Activation { - - /** The TensorFlow Ops */ - protected Ops tf; - - /** - * Creates the abstract class for an Activation - * - * @param tf the TensorFlow Ops - */ - protected Activation(Ops tf) { - this.tf = tf; - } - - /** - * Sets the TensorFlow Ops - * - * @param tf the TensorFlow Ops - */ - protected void setTF(Ops tf) { - this.tf = tf; - } - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - protected Ops getTF() { - return this.tf; - } +@FunctionalInterface +public interface Activation { /** * Gets the calculation operation for the activation. * + * @param tf the TensorFlow Ops * @param input the input tensor * @return The operand for the activation */ - public abstract Operand call(Operand input); + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 2f2f16f2752..919a947a127 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Exponential linear unit. 
* @@ -44,53 +46,41 @@ * Operand<TFloat32> result = elu.call(input); * * - * @param the data type of the activation * @see Clevert et al, 2016, Fast and Accurate Deep * Network Learning by Exponential Linear Units (ELUs) */ -public class ELU extends Activation { +public class ELU extends AbstractActivation { private static final double ALPHA_DEFAULT = 1.0; /** A scalar, slope of negative section. */ private final double alpha; - /** - * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. - * - * @param tf the TensorFlow Ops - */ - public ELU(Ops tf) { - this(tf, ALPHA_DEFAULT); + /** Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. */ + public ELU() { + this(ALPHA_DEFAULT); } /** * Creates a new ELU * - * @param tf the TensorFlow Ops * @param alpha A scalar, slope of negative section. It controls the value to which an ELU * saturates for negative net inputs. */ - public ELU(Ops tf, double alpha) { - super(tf); + public ELU(double alpha) { + super(); this.alpha = alpha; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { Operand result = tf.nn.elu(input); if (alpha == 1.0) return result; else { Class inputType = input.type(); - Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); - Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); + Operand y = tf.math.mul(result, cast(tf, tf.constant(alpha), inputType)); + Operand cond = tf.math.greater(result, cast(tf, tf.constant(0), inputType)); return tf.select(cond, result, y); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java index d5fdff36c61..8398ada6362 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -30,28 +30,17 @@ * Operand<TFloat32> result = exp.call(input); * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f] * - * - * @param the data type of the activation */ -public class Exponential extends Activation { +public class Exponential extends AbstractActivation { - /** - * Creates an Exponential activation. - * - * @param tf the TensorFlow Ops - */ - public Exponential(Ops tf) { - super(tf); + /** Creates an Exponential activation. */ + public Exponential() { + super(); } - /** - * Calculates the Exponential activation. - * - * @param input the input tensor - * @return an Operand for the exponential activation: exp(x). 
- */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.exp(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index 0b7cf573b8e..fac4d14eca5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -18,6 +18,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Hard sigmoid activation. * @@ -40,34 +42,23 @@ * Operand<TFloat32> result = hardSigmoid.call(input); * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f] * - * - * @param the data type of the result */ -public class HardSigmoid extends Activation { +public class HardSigmoid extends AbstractActivation { - /** - * Creates Hard sigmoid activation. - * - * @param tf the TensorFlow Ops - */ - public HardSigmoid(Ops tf) { - super(tf); + /** Creates Hard sigmoid activation. */ + public HardSigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); - Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); + Operand point2 = cast(tf, tf.constant(0.2), inputType); + Operand point5 = cast(tf, tf.constant(0.5), inputType); Operand x = tf.math.add(tf.math.mul(input, point2), point5); return tf.clipByValue( - x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType)); + x, cast(tf, tf.constant(0), inputType), cast(tf, tf.constant(1), inputType)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java index d907397995d..d1a5eede616 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -19,9 +19,9 @@ import org.tensorflow.types.family.TNumber; /** - * Linear activation function (pass-through). + * Linear activation function (pass-through). * - *

The linear activation returns its input. It is also known as the Identity activation function.

+ *

The linear activation returns its input. It is also known as the Identity activation function. * *

For example: * @@ -33,20 +33,16 @@ * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f] * */ -public class Linear extends Activation { +public class Linear extends AbstractActivation { - /** - * Creates a linear activation. - * - * @param tf the TensorFlow Ops - */ - public Linear(Ops tf) { - super(tf); + /** Creates a linear activation. */ + public Linear() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return input; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index aef6ebf2992..c966e5d9ddd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -20,6 +20,8 @@ import org.tensorflow.op.nn.LeakyRelu; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Rectified Linear Unit(ReLU) activation. * @@ -58,7 +60,7 @@ * * @param the data type of the result */ -public class ReLU extends Activation { +public class ReLU extends AbstractActivation { public static final float ALPHA_DEFAULT = 0.0f; public static final float MAX_VALUE_DEFAULT = Float.NaN; @@ -71,24 +73,21 @@ public class ReLU extends Activation { /** * Creates a new ReLU with alpha={@link #ALPHA_DEFAULT}, maxValue={@link #MAX_VALUE_DEFAULT}, * threshold={@link #THRESHOLD_DEFAULT}, - * - * @param tf the TensorFlow Ops */ - public ReLU(Ops tf) { - this(tf, ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); + public ReLU() { + this(ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); } /** * Creates a new ReLU * - * @param tf the TensorFlow Ops * @param alpha governs the slope for values lower than the threshold. * @param maxValue sets the saturation threshold (the largest value the function will return). 
* @param threshold the threshold value of the activation function below which values will be * damped or set to zero. */ - public ReLU(Ops tf, float alpha, float maxValue, float threshold) { - super(tf); + public ReLU(float alpha, float maxValue, float threshold) { + super(); this.alpha = alpha; this.maxValue = maxValue; this.threshold = threshold; @@ -96,7 +95,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); @@ -108,7 +107,7 @@ public Operand call(Operand input) { if (threshold != 0) { negativePart = tf.nn.relu( - tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), inputType))); + tf.math.add(tf.math.neg(input), cast(tf, tf.constant(threshold), inputType))); } else { negativePart = tf.nn.relu(tf.math.neg(input)); } @@ -117,8 +116,8 @@ public Operand call(Operand input) { Operand lInput; if (threshold != 0) { // computes input for input > threshold else 0 - Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType)); - lInput = tf.math.mul(input, tf.dtypes.cast(greater, inputType)); + Greater greater = tf.math.greater(input, cast(tf, tf.constant(threshold), inputType)); + lInput = tf.math.mul(input, cast(tf, greater, inputType)); } else if (maxValue == 6) { // if no threshold, then can use nn.relu6 native TF op for performance lInput = tf.nn.relu6(input); @@ -127,15 +126,14 @@ public Operand call(Operand input) { lInput = tf.nn.relu(input); } if (clipMax) { - Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); - Operand zero = tf.dtypes.cast(tf.constant(0), inputType); + Operand lmaxValue = cast(tf, tf.constant(maxValue), inputType); + Operand zero = cast(tf, tf.constant(0), inputType); lInput = tf.clipByValue(lInput, zero, lmaxValue); } if (alpha != 0.) 
{ lInput = - tf.math.sub( - lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), inputType), negativePart)); + tf.math.sub(lInput, tf.math.mul(cast(tf, tf.constant(alpha), inputType), negativePart)); } return lInput; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java index f24731049fb..a28052486e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -45,25 +45,16 @@ * @param the data type of the activation * @see Klambauer et al., 2017 */ -public class SELU extends Activation { +public class SELU extends AbstractActivation { - /** - * Creates a Scaled Exponential Linear Unit (SELU) activation. - * - * @param tf the TensorFlow Ops - */ - public SELU(Ops tf) { - super(tf); + /** Creates a Scaled Exponential Linear Unit (SELU) activation. */ + public SELU() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.selu(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java index 5d507b38483..02b2daae4d6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -41,25 +41,16 @@ * * @param the data type of the activation */ -public class Sigmoid extends Activation { +public class Sigmoid extends AbstractActivation { - /** - * Creates a Sigmoid activation. 
- * - * @param tf the TensorFlow Ops - */ - public Sigmoid(Ops tf) { - super(tf); + /** Creates a Sigmoid activation. */ + public Sigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.sigmoid(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index 154e1ecc84a..3aa67a179ad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -38,7 +38,7 @@ * * @param the data type of the activation */ -public class Softmax extends Activation { +public class Softmax extends AbstractActivation { private static final int AXIS_DEFAULT = -1; @@ -47,32 +47,24 @@ public class Softmax extends Activation { /** * Creates a softmax activation where the default axis is {@link #AXIS_DEFAULT} which indicates * the last dimension. - * - * @param tf the TensorFlow Ops */ - public Softmax(Ops tf) { - this(tf, AXIS_DEFAULT); + public Softmax() { + this(AXIS_DEFAULT); } /** * Creates a Softmax activation * - * @param tf the TensorFlow Ops * @param axis The dimension softmax would be performed on. */ - public Softmax(Ops tf, int axis) { - super(tf); + public Softmax(int axis) { + super(); this.axis = axis; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java index 65a183ea047..8533de7852c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -32,25 +32,16 @@ * // 1.3132616e+00f, 2.0000000e+01f] * */ -public class Softplus extends Activation { +public class Softplus extends AbstractActivation { - /** - * Creates a Softplus activation function. - * - * @param tf the TensorFlow Ops - */ - public Softplus(Ops tf) { - super(tf); + /** Creates a Softplus activation function. */ + public Softplus() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.softplus(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java index 1f691e71862..249fa6077cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -33,25 +33,16 @@ * * @param the data type of the activation */ -public class Softsign extends Activation { +public class Softsign extends AbstractActivation { - /** - * Creates a Softsign activation. - * - * @param tf the TensorFlow Ops - */ - public Softsign(Ops tf) { - super(tf); + /** Creates a Softsign activation. */ + public Softsign() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.softsign(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java index d9f73a422d5..5007dd34555 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -40,7 +40,7 @@ * @param the data type of the activation * @see Ramachandran et al., 2017 */ -public class Swish extends Activation { +public class Swish extends AbstractActivation { /** * Creates a Swish activation, swish(x) = x * sigmoid(x). 
@@ -48,17 +48,14 @@ public class Swish extends Activation { *

Swish activation function which returns x*sigmoid(x). It is a smooth, * non-monotonic function that consistently matches or outperforms ReLU on deep networks, it is * unbounded above and bounded below. - * - * @param tf the TensorFlow Ops */ - public Swish(Ops tf) { - super(tf); + public Swish() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. return tf.math.mul(input, tf.math.sigmoid(input)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java index 4fe02eed048..37d4d811a0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -33,20 +33,16 @@ * * @param the data type of the activation */ -public class Tanh extends Activation { +public class Tanh extends AbstractActivation { - /** - * Creates a Hyperbolic tangent activation. - * - * @param tf the TensorFlow Ops - */ - public Tanh(Ops tf) { - super(tf); + /** Creates a Hyperbolic tangent activation. */ + public Tanh() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.tanh(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java new file mode 100644 index 00000000000..266d01620bd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Base class for Constraints. AbstractConstraint subclasses impose constraints on weight values */ +public abstract class AbstractConstraint implements Constraint { + + public static final float EPSILON = 1e-7f; + + /** Creates a AbstractConstraint */ + public AbstractConstraint() {} + + /** + * Gets the element-wise square root. + * + * @param tf the TensorFlow Ops + * @param x the input Operand. + * @return the element-wise square root. + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand sqrt(Ops tf, Operand x) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); + } + + /** + * Gets the element-wise value clipping. 
+ * + * @param tf the TensorFlow Ops + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand clip( + Ops tf, Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } + + /** + * Calculates the norm of the weights along the axes + * + * @param tf the TensorFlow Ops + * @param weights the weights used to calculate the norms + * @param axes the axes along which to calculate weight norms. + * @param the data type for the weights and the result + * @return the norms + * @throws IllegalArgumentException if weights is null + */ + protected Operand norm(Ops tf, Operand weights, int[] axes) { + if (weights == null) throw new IllegalArgumentException("weights must not be null"); + return sqrt( + tf, + tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index 306361959bf..97640b19cf8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,96 +16,16 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** Base class for Constraints. Constraint subclasses impose constraints on weight values */ -public abstract class Constraint { - - public static final float EPSILON = 1e-7f; - - private final Ops tf; - - /** - * Creates a Constraint - * - * @param tf the TensorFlow Ops - */ - public Constraint(Ops tf) { - this.tf = tf; - } - +public interface Constraint { /** * Applies the constraint against the provided weights * + * @param tf the TensorFlow Ops * @param weights the weights * @return the constrained weights * @param the data type for weights and results. */ - public abstract Operand call(Operand weights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the element-wise square root. - * - * @param x the input Operand. - * @return the element-wise square root. - * @param The data type for the operand and result. - * @throws IllegalArgumentException if x is null - */ - protected Operand sqrt(Operand x) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Class type = x.type(); - Operand zero = cast(tf, tf.constant(0), type); - Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); - return tf.math.sqrt(tf.clipByValue(x, zero, inf)); - } - - /** - * Gets the element-wise value clipping. - * - * @param x the Operand to clip - * @param minValue the minimum value - * @param maxValue the maximum value - * @return the operand with clipped values - * @param The data type for the operand and result. 
- * @throws IllegalArgumentException if x is null - */ - protected Operand clip(Operand x, double minValue, double maxValue) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Ops tf = getTF(); - Class type = x.type(); - - double min = Math.min(minValue, maxValue); - double max = Math.max(minValue, maxValue); - - Operand minValueConstant = cast(tf, tf.constant(min), type); - Operand maxValueConstant = cast(tf, tf.constant(max), type); - return tf.clipByValue(x, minValueConstant, maxValueConstant); - } - - /** - * Calculates the norm of the weights along the axes - * - * @param weights the weights used to calculate the norms - * @param axes the axes along which to calculate weight norms. - * @param the data type for the weights and the result - * @return the norms - * @throws IllegalArgumentException if weights is null - */ - protected Operand norm(Operand weights, int[] axes) { - if (weights == null) throw new IllegalArgumentException("weights must not be null"); - return sqrt( - tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); - } + Operand call(Ops tf, Operand weights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index 1dae117b113..b9f082f54de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -24,7 +24,7 @@ * Constrains the weights incident to each hidden unit to have a norm less than or equal to a * desired value. 
*/ -public class MaxNorm extends Constraint { +public class MaxNorm extends AbstractConstraint { public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; @@ -36,54 +36,48 @@ public class MaxNorm extends Constraint { /** * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link * #AXIS_DEFAULT} for the axis. - * - * @param tf the TensorFlow Ops */ - public MaxNorm(Ops tf) { - this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + public MaxNorm() { + this(MAX_VALUE_DEFAULT, AXIS_DEFAULT); } /** * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. */ - public MaxNorm(Ops tf, double maxValue) { - this(tf, maxValue, AXIS_DEFAULT); + public MaxNorm(double maxValue) { + this(maxValue, AXIS_DEFAULT); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axis axis along which to calculate weight norms. */ - public MaxNorm(Ops tf, double maxValue, int axis) { - this(tf, maxValue, new int[] {axis}); + public MaxNorm(double maxValue, int axis) { + this(maxValue, new int[] {axis}); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axes axes along which to calculate weight norms. 
*/ - public MaxNorm(Ops tf, double maxValue, int[] axes) { - super(tf); + public MaxNorm(double maxValue, int[] axes) { + super(); this.maxValue = maxValue; this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Operand norms = norm(weights, getAxes()); - Operand desired = clip(norms, 0f, this.getMaxValue()); + Operand norms = norm(tf, weights, getAxes()); + Operand desired = clip(tf, norms, 0f, this.getMaxValue()); return tf.math.mul( weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 04b21572e55..97e86d7693f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -21,7 +21,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have the norm between a lower bound and an upper bound. 
*/ -public class MinMaxNorm extends Constraint { +public class MinMaxNorm extends AbstractConstraint { public static final double MIN_VALUE_DEFAULT = 0.0; public static final double MAX_VALUE_DEFAULT = 1.0; public static final double RATE_DEFAULT = 1.0; @@ -47,48 +47,43 @@ public class MinMaxNorm extends Constraint { * Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis - * - * @param tf the TensorFlow Ops */ - public MinMaxNorm(Ops tf) { - this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm() { + this(MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue) { - this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm(double minValue, double maxValue) { + this(minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axis integer, axis along which to calculate weight norms. 
*/ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { - this(tf, minValue, maxValue, rate, new int[] {axis}); + public MinMaxNorm(double minValue, double maxValue, double rate, int axis) { + this(minValue, maxValue, rate, new int[] {axis}); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axes integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { - super(tf); + public MinMaxNorm(double minValue, double maxValue, double rate, int[] axes) { + super(); this.minValue = minValue; this.maxValue = maxValue; this.rate = rate; @@ -97,15 +92,14 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] a /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); - Operand norms = norm(weights, getAxes()); + Operand norms = norm(tf, weights, getAxes()); Operand desired = tf.math.add( tf.math.mul( tf.dtypes.cast(tf.constant(this.getRate()), type), - clip(norms, this.getMinValue(), this.getMaxValue())), + clip(tf, norms, this.getMinValue(), this.getMaxValue())), tf.math.mul( tf.math.sub( tf.dtypes.cast(tf.constant(1), type), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java index 0194b2fadb6..6a5677983fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -19,21 +19,16 @@ import org.tensorflow.types.family.TNumber; /** Constrains the weights to be 
non-negative. */ -public class NonNeg extends Constraint { +public class NonNeg extends AbstractConstraint { - /** - * Create a NonNeg constraint - * - * @param tf the TensorFlow Ops - */ - public NonNeg(Ops tf) { - super(tf); + /** Create a NonNeg constraint */ + public NonNeg() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); return tf.math.mul( weights, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java index 70bb1a59785..fdd71945229 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -21,50 +21,43 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have unit norm. */ -public class UnitNorm extends Constraint { +public class UnitNorm extends AbstractConstraint { public static final int AXIS_DEFAULT = 0; /** integer, axis along which to calculate weight norms. */ private final int[] axes; - /** - * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - public UnitNorm(Ops tf) { - this(tf, AXIS_DEFAULT); + /** Create a UnitNorm AbstractConstraint with the axis set to {@link #AXIS_DEFAULT} */ + public UnitNorm() { + this(AXIS_DEFAULT); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axis axis along which to calculate weight norms. 
*/ - public UnitNorm(Ops tf, int axis) { - this(tf, new int[] {axis}); + public UnitNorm(int axis) { + this(new int[] {axis}); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axes axes along which to calculate weight norms. */ - public UnitNorm(Ops tf, int[] axes) { - super(tf); + public UnitNorm(int[] axes) { + super(); this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); return tf.math.div( - weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(weights, getAxes()))); + weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(tf, weights, getAxes()))); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java index 9c1fa9ac287..56e3d310280 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java @@ -14,29 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; /** Abstract base class for all Initializers */ public abstract class BaseInitializer implements Initializer { - protected final Ops tf; + private final String name; - /** - * Creates an Initializer - * - * @param tf the TensorFlow Ops - */ - protected BaseInitializer(Ops tf) { - this.tf = tf; + /** Creates an Initializer */ + protected BaseInitializer() { + name = getClass().getSimpleName(); } /** - * Gets the TensorFlow Ops + * Gets the name for this initializer * - * @return the TensorFlow Ops + * @return the name for this initializer */ 
- public Ops getTF() { - return tf; + public String getName() { + return name; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 4a2df86d74b..508fb69fd55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a constant value. * @@ -30,7 +32,7 @@ * Constant<TFloat32> initializer = * new org.tensorflow.framework.initializers.Constant<>(tf, 3f); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The Type for the call operation @@ -45,11 +47,10 @@ public class Constant extends BaseInitializer { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a long value used for the constant. */ - public Constant(Ops tf, long value) { - super(tf); + public Constant(long value) { + super(); longValue = value; doubleValue = 0; booleanValue = false; @@ -59,11 +60,10 @@ public Constant(Ops tf, long value) { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a double value used for the constant. */ - public Constant(Ops tf, double value) { - super(tf); + public Constant(double value) { + super(); doubleValue = value; longValue = 0; booleanValue = false; @@ -73,11 +73,10 @@ public Constant(Ops tf, double value) { /** * Creates an Initializer that generates tensors with a constant value. 
* - * @param tf the TensorFlow Ops * @param value a boolean value used for the constant. */ - public Constant(Ops tf, boolean value) { - super(tf); + public Constant(boolean value) { + super(); booleanValue = value; doubleValue = 0; longValue = 0; @@ -86,17 +85,19 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); + return tf.fill(dims, cast(tf, tf.constant(longValue), type)); case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); + return tf.fill(dims, cast(tf, tf.constant(doubleValue), type)); default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); + return tf.fill(dims, cast(tf, tf.constant(booleanValue), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 894bd073758..4a39c3839f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -43,7 +42,7 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, 
tf.constant(Shape.of(2,2)), TFloat32.class); * * *

Glorot Uniform: @@ -54,12 +53,14 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

NOTE: + * *

For a GlorotNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. + * *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. * @@ -74,13 +75,12 @@ public class Glorot extends VarianceScaling { /** * Creates a Glorot initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public Glorot(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_AVG, distribution, seed); + public Glorot(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_AVG, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 3a91b72b0d0..4a9fa8a7849 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -38,7 +37,7 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.TRUNCATED_NORMAL, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

He Uniform: @@ -49,14 +48,16 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.UNIFORM, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

NOTE: + * *

For an HeNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. - *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} - * for the distribution parameter. + * + *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for + * the distribution parameter. * * @param The TType for the call operation * @see extends VarianceScaling { /** * Creates an He Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the He initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public He(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_IN, distribution, seed); + public He(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index f672c9f1e85..34a77520406 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -21,6 +21,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates the identity matrix. * @@ -32,40 +34,34 @@ * Identity<TFloat32> initializer = * new org.tensorflow.framework.initializers.Identity<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; - private final double gain; - /** - * Creates an Initializer that generates the identity matrix. 
- * - * @param tf the TensorFlow Ops - */ - public Identity(Ops tf) { - super(tf); - this.gain = GAIN_DEFAULT; + /** Creates an Initializer that generates the identity matrix. */ + public Identity() { + this(GAIN_DEFAULT); } /** * Creates an Initializer that generates the identity matrix. * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Identity Matrix */ - public Identity(Ops tf, double gain) { - super(tf); + public Identity(double gain) { + super(); this.gain = gain; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -75,9 +71,9 @@ public Operand call(Operand dims, Class type) { Shape diagShape = Shape.of(diagSize); Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), type); + Operand zero = cast(tf, tf.constant(0), type); Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); + tf.fill(tf.constant(diagShape.asArray()), cast(tf, tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -91,6 +87,6 @@ public Operand call(Operand dims, Class type) { op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); + return tf.math.mul(op, cast(tf, tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 4beb218783b..d6593b770e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -15,6 +15,7 @@ package 
org.tensorflow.framework.initializers; import org.tensorflow.Operand; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -23,14 +24,18 @@ * * @param The data Type for initializer operation */ +@FunctionalInterface public interface Initializer { /** * Generates the operation used to perform the initialization. * + * @param tf the TensorFlow Ops * @param dims the shape dimensions * @param type the type of tensor + * @throws IllegalStateException if the object has not been initialized with the TensorFlow + * Platform. * @return An operand for the initialization. */ - Operand call(Operand dims, Class type); + Operand call(Ops tf, Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index 38e68ef688b..364c5fb9285 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -27,7 +26,7 @@ * stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the * weight tensor. * - *

If the distribution is UNIFORM, itraws samples from a uniform distribution within + *

If the distribution is UNIFORM, it draws samples from a uniform distribution within * [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is * the number of input units in the weight tensor) * @@ -41,7 +40,7 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

LeCun Uniform: @@ -52,14 +51,15 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * * * *

NOTE: * * - *

For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * + *

For a LeCunNormal equivalent initializer, use {@link + * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * * *

For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * * for the distribution parameter. * @@ -79,12 +79,11 @@ public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public LeCun(Ops tf, Distribution distribution, long seed) { - super(tf, 1.0, Mode.FAN_IN, distribution, seed); + public LeCun(Distribution distribution, long seed) { + super(1.0, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b8eb0c418e9..6e818d30bd7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors initialized to 1. 
* @@ -30,7 +32,7 @@ * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,21 +48,21 @@ public class Ones extends BaseInitializer { * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param tf the TensorFlow Ops */ - public Ones(Ops tf) { - super(tf); + public Ones() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); + return tf.fill(dims, cast(tf, tf.constant(1), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index a5b466e118e..519d0cd042e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -23,6 +23,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates an orthogonal matrix. 
* @@ -42,7 +44,7 @@ * Orthogonal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.Orthogonal<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -57,31 +59,30 @@ public class Orthogonal extends BaseInitializer { /** * Creates an Orthogonal Initializer using {@link #GAIN_DEFAULT} for the gain. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public Orthogonal(Ops tf, long seed) { - this(tf, GAIN_DEFAULT, seed); + public Orthogonal(long seed) { + this(GAIN_DEFAULT, seed); } /** * Creates an Orthogonal Initializer * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Matrix. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public Orthogonal(Ops tf, double gain, long seed) { - super(tf); + public Orthogonal(double gain, long seed) { + super(); this.gain = gain; this.seed = seed; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -101,10 +102,10 @@ public Operand call(Operand dims, Class type) { Output qo = qrOp.q(); Output ro = qrOp.r(); Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); + tf.linalg.matrixDiagPart(ro, tf.constant(0), cast(tf, tf.constant(0), type)); Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); + return tf.math.mul(qop, cast(tf, tf.constant(this.gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 38ab194a56b..9a52a641416 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a normal distribution. 
* @@ -29,7 +31,7 @@ * RandomNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,37 +49,34 @@ public class RandomNormal extends BaseInitializer { * Creates the RandomUniform initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public RandomNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates the RandomUniform initializer using {@link #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, double mean, long seed) { - this(tf, mean, STDDEV_DEFAULT, seed); + public RandomNormal(double mean, long seed) { + this(mean, STDDEV_DEFAULT, seed); } /** * creates the RandomUniform initializer * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public RandomNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -85,10 +84,11 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); + Operand op = tf.math.mul(distOp, cast(tf, tf.constant(this.stddev), type)); + return tf.math.add(op, cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 787af15f709..7288024f5b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a uniform distribution. 
* @@ -31,7 +33,7 @@ * RandomUniform<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomUniform<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,28 +48,26 @@ public class RandomUniform extends BaseInitializer { private final long seed; /** - * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and - * {@link #MAXVAL_DEFAULT} for the maxval + * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and {@link + * #MAXVAL_DEFAULT} for the maxval * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomUniform(Ops tf, long seed) { - this(tf, MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); + public RandomUniform(long seed) { + this(MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); } /** * Creates a RandomUniform initializer * - * @param tf the TensorFlow Ops * @param minval Lower bound of the range of random values to generate (inclusive). * @param maxval Upper bound of the range of random values to generate (exclusive). * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomUniform(Ops tf, double minval, double maxval, long seed) { - super(tf); + public RandomUniform(double minval, double maxval, long seed) { + super(); this.minval = minval; this.maxval = maxval; this.seed = seed; @@ -75,26 +75,27 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Operand distOp; if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), type), - tf.dtypes.cast(tf.constant(this.maxval), type), + cast(tf, tf.constant(this.minval), type), + cast(tf, tf.constant(this.maxval), type), options); } else { long[] seeds = {seed, 0}; distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval), type)); } } else { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); - distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval - this.minval), type)); + distOp = tf.math.add(distOp, cast(tf, tf.constant(this.minval), type)); } } return distOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index d3cfec26338..8069d5d9c7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -19,6 +19,8 @@ import 
org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates a truncated normal distribution. * @@ -29,7 +31,7 @@ * TruncatedNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.TruncatedNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,25 +49,23 @@ public class TruncatedNormal extends BaseInitializer { * Creates a TruncatedNormal Initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public TruncatedNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public TruncatedNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates a TruncatedNormal Initializer. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public TruncatedNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -73,11 +73,12 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - long[] seeds = {seed,0}; + public Operand call(Ops tf, Operand dims, Class type) { + + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), - tf.dtypes.cast(tf.constant(mean), type)); + tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)), + cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index 5d951450505..a04e4a9a378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -21,11 +21,13 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from - * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after + *

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a + * truncated/untruncated normal distribution with a mean of zero and a standard deviation (after * truncation, if used) stddev = Math.sqrt(scale / n), where n is: * *

    @@ -46,7 +48,7 @@ * new org.tensorflow.framework.initializers.VarianceScaling<>( * tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -64,28 +66,25 @@ public class VarianceScaling extends BaseInitializer { private final Distribution distribution; private final long seed; - /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param seed sed to create random seeds. */ - public VarianceScaling(Ops tf, long seed) { - this(tf, SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); + public VarianceScaling(long seed) { + this(SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); } /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param scale Scaling factor (positive float). * @param mode the mode for the variance * @param distribution Random distribution to use. * @param seed Used to create random seeds. 
*/ - public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distribution, long seed) { - super(tf); + public VarianceScaling(double scale, Mode mode, Distribution distribution, long seed) { + super(); if (scale <= 0.0) { throw new IllegalArgumentException("scale must be greater than 0, got " + scale); } @@ -97,8 +96,9 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); + public Operand call(Ops tf, Operand dims, Class type) { + + Shape shape = ShapeUtils.toShape(tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); switch (mode) { @@ -119,18 +119,18 @@ public Operand call(Operand dims, Class type) { switch (distribution) { case TRUNCATED_NORMAL: distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); - stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + stddev = Math.sqrt(lscale) / 0.87962566103423978; + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case NORMAL: distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case UNIFORM: distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; } return mulOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 4298493ac44..f581d247deb 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -28,24 +28,21 @@ * Zeros<TFloat32> initializer = * new org.tensorflow.framework.initializers.Zeros<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Zeros extends BaseInitializer { - /** - * Creates an Initializer that sets all values to one. - * - * @param tf the TensorFlow Ops - */ - public Zeros(Ops tf) { - super(tf); + /** Creates an Initializer that sets all values to zero. */ + public Zeros() { + super(); } @Override - public Operand call(Operand dims, Class dtype) { - return tf.zeros(dims, dtype); + public Operand call(Ops tf, Operand dims, Class type) { + + return tf.zeros(dims, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 3417c07372a..0c7c6abf8af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * BinaryCrossentropy bce = new BinaryCrossentropy(tf); - * Operand<TFloat32> result = bce.call(labels, predictions); + * Operand<TFloat32> result = bce.call(tf, labels, predictions); * // produces 0.815
* * @@ -43,7 +44,7 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
    - *    Operand<TFloat32> result = bce.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = bce.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.458f
      * 
    * @@ -51,7 +52,7 @@ * *
      *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = bce.call(labels, predictions);
    + *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
      *    // produces 1.630f
      * 
    * @@ -59,11 +60,11 @@ * *
      *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = bce.call(labels, predictions);
    + *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
      *    // produces [0.916f, 0.714f]
      * 
    */ -public class BinaryCrossentropy extends Loss { +public class BinaryCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; @@ -71,70 +72,63 @@ public class BinaryCrossentropy extends Loss { private final float labelSmoothing; /** - * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Binary Crossentropy AbstractLoss using {@link Class#getSimpleName()} as the loss + * name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public BinaryCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link * #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public BinaryCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); + public BinaryCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link - * Loss#REDUCTION_DEFAULT}, + * AbstractLoss#REDUCTION_DEFAULT}, * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a - * reduction of {@link Loss#REDUCTION_DEFAULT}. + * reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, - * and a reduction of {@link Loss#REDUCTION_DEFAULT}. + * and a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. 
When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** - * Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}. + * Creates a Binary Crossentropy loss using a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -142,14 +136,13 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, @@ -157,14 +150,13 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSm * correspond to heavier smoothing. 
* @param reduction Type of Reduction to apply to the loss. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -175,8 +167,8 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Redu * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public BinaryCrossentropy( - Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. 
and <= 1, found " + labelSmoothing); @@ -207,24 +199,25 @@ public BinaryCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.binaryCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 5aac163c1e4..7d65353b004 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * CategoricalCrossentropy cce = new CategoricalCrossentropy(tf); - * 
Operand<TFloat32> result = cce.call(labels, predictions); + * Operand<TFloat32> result = cce.call(tf, labels, predictions); * // produces 1.177 * * @@ -45,15 +46,15 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
    - *    Operand<TFloat32> result = cce.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = cce.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.814f
      * 
    * *

    Using SUM reduction type: * *

    - *    CategoricalCrossentropy cce = new CategoricalCrossentropy(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = cce.call(labels, predictions);
    + *    CategoricalCrossentropy cce = new CategoricalCrossentropy(Reduction.SUM);
    + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
      *    // produces 2.354f
      * 
    * @@ -61,12 +62,12 @@ * *
      *    CategoricalCrossentropy cce =
    - *        new CategoricalCrossentropy(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = cce.call(labels, predictions);
    + *        new CategoricalCrossentropy(Reduction.NONE);
    + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
      *    // produces [0.0513f, 2.303f]
      * 
    */ -public class CategoricalCrossentropy extends Loss { +public class CategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST; @@ -76,98 +77,90 @@ public class CategoricalCrossentropy extends Loss { private final int axis; /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and an axis of {@link - * #DEFAULT_AXIS} - * - * @param tf the TensorFlow Ops + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis + * of {@link #DEFAULT_AXIS} */ - public CategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #FROM_LOGITS_DEFAULT} for + * fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss */ - public CategoricalCrossentropy(Ops tf, String name) { - this(tf, name, 
FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for * labelSmoothing and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link - * #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss {@link #FROM_LOGITS_DEFAULT} for fromLogits, + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param reduction Type of Reduction to apply to loss. 
*/ - public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of - * {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a + * channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, 
REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a channel + * axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -175,15 +168,14 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) *
    means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means @@ -191,15 +183,13 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * for label 1 * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy( - Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss + * Creates a categorical cross entropy AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -213,13 +203,8 @@ public CategoricalCrossentropy( * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. 
*/ public CategoricalCrossentropy( - Ops tf, - String name, - boolean fromLogits, - float labelSmoothing, - Reduction reduction, - int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction, int axis) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); @@ -251,24 +236,24 @@ public CategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.categoricalCrossentropy( - getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.categoricalCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 73837ed1756..c9987fb0884 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import 
org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * CategoricalHinge categoricalHinge = new CategoricalHinge(tf); - * Operand<TFloat32> result = categoricalHinge.call(labels, predictions); + * Operand<TFloat32> result = categoricalHinge.call(Ops tf, labels, predictions); * // produces 1.4 * * @@ -43,7 +44,7 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1f, 0.f});
    - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.6f
      * 
    * @@ -51,7 +52,7 @@ * *
      *    CategoricalHinge categoricalHinge = new CategoricalHinge(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
    + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
      *    // produces 2.8f
      * 
    * @@ -60,48 +61,45 @@ *
      *    CategoricalHinge categoricalHinge =
      *        new CategoricalHinge(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
    + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
      *    // produces [1.2f, 1.6f]
      * 
    */ -public class CategoricalHinge extends Loss { +public class CategoricalHinge extends AbstractLoss { /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public CategoricalHinge(Ops tf) { - super(tf); + public CategoricalHinge() { + super(); } /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CategoricalHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public CategoricalHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Categorical Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public CategoricalHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public CategoricalHinge(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.categoricalHinge(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 0a18d93caf3..ac810139d71 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -40,7 +41,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 0.f}, {1.f, 1.f}}); * CosineSimilarity cosineLoss = new CosineSimilarity(tf); - * Operand<TFloat32> result = cosineLoss.call(labels, predictions); + * Operand<TFloat32> result = cosineLoss.call(Ops tf, labels, predictions); * // produces -0.5 * * @@ -48,7 +49,7 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
    - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions, sampleWeight);
      *    // produces -0.0999f
      * 
    * @@ -56,7 +57,7 @@ * *
      *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
    + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
      *    // produces -0.999f
      * 
    * @@ -64,165 +65,155 @@ * *
      *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
    + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
      *    // produces [-0.f, -0.999f]
      * 
    */ -public class CosineSimilarity extends Loss { +public class CosineSimilarity extends AbstractLoss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; private final int[] axis; /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis - * of {@link #DEFAULT_AXIS}, and a Loss Reduction of {@link #DEFAULT_REDUCTION} - * - * @param tf the TensorFlow Ops + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * an axis of {@link #DEFAULT_AXIS}, and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} */ - public CosineSimilarity(Ops tf) { + public CosineSimilarity() { - this(tf, null, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(null, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS}, and a Loss Reduction - * of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS}, and a + * AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss */ - public CosineSimilarity(Ops tf, String name) { + public CosineSimilarity(String name) { - this(tf, name, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(name, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, int axis) { + public CosineSimilarity(int axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, int[] axis) { + public CosineSimilarity(int[] axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, String name, int axis) { + public CosineSimilarity(String name, int axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis) { + public CosineSimilarity(String name, int[] axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an - * axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, Reduction reduction) { + public CosineSimilarity(Reduction reduction) { - this(tf, null, DEFAULT_AXIS, reduction); + this(null, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, Reduction reduction) { + public CosineSimilarity(String name, Reduction reduction) { - this(tf, name, DEFAULT_AXIS, reduction); + this(name, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + public CosineSimilarity(int axis, Reduction reduction) { - this(tf, null, new int[] {axis}, reduction); + this(null, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + public CosineSimilarity(int[] axis, Reduction reduction) { - this(tf, null, axis, reduction); + this(null, axis, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { - this(tf, name, new int[] {axis}, reduction); + public CosineSimilarity(String name, int axis, Reduction reduction) { + this(name, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { - super(tf, name, reduction); + public CosineSimilarity(String name, int[] axis, Reduction reduction) { + super(name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.cosineSimilarity(tf, labels, predictions, axis); losses = tf.math.neg(losses); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index d4c350ef06c..05c5b47e329 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Hinge hingeLoss = new Hinge(tf); - * Operand<TFloat32> result = hingeLoss.call(labels, predictions); + * Operand<TFloat32> result = hingeLoss.call(Ops tf, labels, predictions); * // produces 1.3f * * @@ -45,57 +46,53 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
    - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.55f
      * 
    * *

    Using SUM reduction type: * *

    - *    Hinge hingeLoss = new Hinge(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
    + *    Hinge hingeLoss = new Hinge(Reduction.SUM);
    + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
      *    // produces 2.6f
      * 
    * *

    Using NONE reduction type: * *

    - *    Hinge hingeLoss = new Hinge(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
    + *    Hinge hingeLoss = new Hinge(Reduction.NONE);
    + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
      *    // produces [1.1f, 1.5f]
      * 
    */ -public class Hinge extends Loss { +public class Hinge extends AbstractLoss { /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction - * of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Hinge(Ops tf) { - this(tf, null, Reduction.AUTO); + public Hinge() { + this(null, Reduction.AUTO); } /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Hinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public Hinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public Hinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Hinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -122,15 +119,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.hinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.hinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index b1aee1b0656..c9a7d7edcb8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -39,7 +40,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Huber huberLoss = new Huber(tf); - * Operand<TFloat32> result = huberLoss.call(labels, predictions); + * Operand<TFloat32> result = huberLoss.call(Ops tf, labels, predictions); * // produces 0.155 * * 
@@ -47,7 +48,7 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
    - *    Operand<TFloat32> result = huberLoss.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.09f
      * 
    * @@ -55,7 +56,7 @@ * *
      *    Huber huberLoss = new Huber(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
    + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
      *    // produces 0.32f
      * 
    * @@ -63,78 +64,74 @@ * *
      *    Huber huberLoss = new Huber(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
    + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
      *    // produces [0.18f, 0.13f]
      * 
    * * @see
    Huber loss */ -public class Huber extends Loss { +public class Huber extends AbstractLoss { public static final float DELTA_DEFAULT = 1.0f; private final float delta; /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name, {@link + * #DELTA_DEFAULT} as the delta and a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} */ - public Huber(Ops tf) { - this(tf, null, DELTA_DEFAULT, Reduction.AUTO); + public Huber() { + this(null, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Huber(Ops tf, String name) { - this(tf, name, DELTA_DEFAULT, Reduction.AUTO); + public Huber(String name) { + this(name, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name and and {@link - * #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name and and + * {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, Reduction reduction) { - this(tf, null, DELTA_DEFAULT, reduction); + public Huber(Reduction reduction) { + this(null, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ - public Huber(Ops tf, String name, Reduction reduction) { - this(tf, name, DELTA_DEFAULT, reduction); + public Huber(String name, Reduction reduction) { + this(name, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss + * Creates a Huber AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, String name, float delta, Reduction reduction) { - super(tf, name, reduction); + public Huber(String name, float delta, Reduction reduction) { + super(name, reduction); this.delta = delta; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.huber(getTF(), labels, predictions, delta); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.huber(tf, labels, predictions, delta); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 2aa1f72092b..ef5d88539db 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -31,8 +32,8 @@ * tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}}); * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); - * KLDivergence kld = new KLDivergence(tf); - * Operand<TFloat32> result = kld.call(labels, predictions); + * KLDivergence kld = new KLDivergence(); + * Operand<TFloat32> result = kld.call(Ops tf, labels, predictions); * // produces 0.458 * * @@ -40,68 +41,65 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
    - *    Operand<TFloat32> result = kld.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = kld.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.366f
      * 
    * *

    Using SUM reduction type: * *

    - *    KLDivergence kld = new KLDivergence(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = kld.call(labels, predictions);
    + *    KLDivergence kld = new KLDivergence(Reduction.SUM);
    + *    Operand<TFloat32> result = kld.call(tf, labels, predictions);
      *    // produces 0.916f
      * 
    * *

    Using NONE reduction type: * *

    - *    KLDivergence kld = new KLDivergence(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = kld.call(labels, predictions);
    + *    KLDivergence kld = new KLDivergence(Reduction.NONE);
    + *    Operand<TFloat32> result = kld.call(tf, labels, predictions);
      *    // produces [0.916f, -3.08e-06f]
      * 
    * * @see Kullback?Leibler * divergence */ -public class KLDivergence extends Loss { +public class KLDivergence extends AbstractLoss { /** - * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Kullback Leibler Divergence AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public KLDivergence(Ops tf) { - super(tf); + public KLDivergence() { + super(); } /** - * Creates a Kullback Leibler Divergence Loss Loss using {@link Class#getSimpleName()} as the loss - * name + * Creates a Kullback Leibler Divergence AbstractLoss AbstractLoss using {@link + * Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public KLDivergence(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public KLDivergence(Reduction reduction) { + super(null, reduction); } /** - * Creates a Kullback Leibler Divergence Loss + * Creates a Kullback Leibler Divergence AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public KLDivergence(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public KLDivergence(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.kullbackLeiblerDivergence(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index a11d582e527..02200c3a9e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -33,7 +34,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * LogCosh logcosh = new LogCosh(tf); - * Operand<TFloat32> result = logcosh.call(labels, predictions); + * Operand<TFloat32> result = logcosh.call(Ops tf, labels, predictions); * // produces 0.108 * * @@ -41,74 +42,71 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
    - *    Operand<TFloat32> result = logcosh.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.087f
      * 
    * *

    Using SUM reduction type: * *

    - *    LogCosh logcosh = new LogCosh(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
    + *    LogCosh logcosh = new LogCosh(Reduction.SUM);
    + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions);
      *    // produces 0.217f
      * 
    * *

    Using NONE reduction type: * *

    - *    LogCosh logcosh = new LogCosh(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
    + *    LogCosh logcosh = new LogCosh(Reduction.NONE);
    + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions);
      *    // produces [0.217f, 0f]
      * 
    */ -public class LogCosh extends Loss { +public class LogCosh extends AbstractLoss { /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public LogCosh(Ops tf) { - this(tf, null, Reduction.AUTO); + public LogCosh() { + this(null, Reduction.AUTO); } /** - * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a LogCosh AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public LogCosh(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public LogCosh(String name) { + this(name, Reduction.AUTO); } /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public LogCosh(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public LogCosh(Reduction reduction) { + this(null, reduction); } /** - * Creates a LogCosh Loss + * Creates a LogCosh AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public LogCosh(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public LogCosh(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.logCosh(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.logCosh(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index cdd35d28aba..4dd5bce6cde 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -18,60 +18,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -public abstract class Loss { - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; - - protected final Ops tf; - protected final Reduction reduction; - - /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - protected Loss(Ops tf) { - this(tf, null, Reduction.AUTO); - } - - /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. - */ - protected Loss(Ops tf, String name) { - this(tf, name, Reduction.AUTO); - } - - /** - * Creates a Loss - * - * @param tf the TensorFlow Ops - * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. - * @param reduction Type of Reduction to apply to the loss. - */ - protected Loss(Ops tf, String name, Reduction reduction) { - this.tf = name != null ? tf.withSubScope(name) : tf.withSubScope(getClass().getSimpleName()); - this.reduction = reduction; - } - - /** - * Calculates the loss - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param The data type of the predictions and loss. - * @return the loss - */ - public Operand call( - Operand labels, Operand predictions) { - return call(labels, predictions, null); - } +/** Interface for loss calculation */ +@FunctionalInterface +public interface Loss { /** * Generates an Operand that calculates the loss. * + * @param tf the TensorFlow Ops * @param labels the truth values or labels * @param predictions the predictions * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is @@ -84,24 +38,6 @@ public Operand call( * @param The data type of the predictions, sampleWeights and loss. 
* @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the loss reduction - * - * @return the loss reduction - */ - public Reduction getReduction() { - return reduction; - } + Operand call( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 03a3cf70110..d85bdf3561a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsoluteError mae = new MeanAbsoluteError(tf); - * Operand<TFloat32> result = mae.call(labels, predictions); + * Operand<TFloat32> result = mae.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
    - *    Operand<TFloat32> result = mae.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = mae.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.25f
      * 
    * *

    Using SUM reduction type: * *

    - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = mae.call(labels, predictions);
    + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.SUM);
    + *    Operand<TFloat32> result = mae.call(tf, labels, predictions);
      *    // produces 1.0f
      * 
    * *

    Using NONE reduction type: * *

    - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = mae.call(labels, predictions);
    + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.NONE);
    + *    Operand<TFloat32> result = mae.call(tf, labels, predictions);
      *    // produces [0.5f, 0.5f]
      * 
    */ -public class MeanAbsoluteError extends Loss { +public class MeanAbsoluteError extends AbstractLoss { /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsoluteError(Ops tf) { - super(tf); + public MeanAbsoluteError() { + super(); } /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsoluteError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsoluteError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsoluteError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsoluteError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsoluteError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 6c5242df4f2..ed5c7d73e2f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf); - * Operand<TFloat32> result = mape.call(labels, predictions); + * Operand<TFloat32> result = mape.call(Ops tf, labels, predictions); * // produces 50f * * @@ -40,64 +41,62 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
    - *    Operand<TFloat32> result = mape.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = mape.call(tf, labels, predictions, sampleWeight);
      *    // produces 20f
      * 
    * *

    Using SUM reduction type: * *

    - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = mape.call(labels, predictions);
    + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.SUM);
    + *    Operand<TFloat32> result = mape.call(tf, labels, predictions);
      *    // produces 100.0f
      * 
    * *

    Using NONE reduction type: * *

    - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = mape.call(labels, predictions);
    + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.NONE);
    + *    Operand<TFloat32> result = mape.call(tf, labels, predictions);
      *    // produces [25f, 75f]
      * 
    */ -public class MeanAbsolutePercentageError extends Loss { +public class MeanAbsolutePercentageError extends AbstractLoss { /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsolutePercentageError(Ops tf) { - super(tf); + public MeanAbsolutePercentageError() { + super(); } /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsolutePercentageError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsolutePercentageError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsolutePercentageError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsolutePercentageError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsolutePercentageError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index f975db55c44..c6898e20f20 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredError mse = new MeanSquaredError(tf); - * Operand<TFloat32> result = mse.call(labels, predictions); + * Operand<TFloat32> result = mse.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
    - *    Operand<TFloat32> result = mse.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = mse.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.25f
      * 
    * *

    Using SUM reduction type: * *

    - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = mse.call(labels, predictions);
    + *    MeanSquaredError mse = new MeanSquaredError(Reduction.SUM);
    + *    Operand<TFloat32> result = mse.call(tf, labels, predictions);
      *    // produces 1.0f
      * 
    * *

    Using NONE reduction type: * *

    - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = mse.call(labels, predictions);
    + *    MeanSquaredError mse = new MeanSquaredError(Reduction.NONE);
    + *    Operand<TFloat32> result = mse.call(tf, labels, predictions);
      *    // produces [0.5f, 0.5f]
      * 
    */ -public class MeanSquaredError extends Loss { +public class MeanSquaredError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredError(Ops tf) { - super(tf); + public MeanSquaredError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 11b8e157e90..3d325a98a6a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf); - * Operand<TFloat32> result = msle.call(labels, predictions); + * Operand<TFloat32> result = msle.call(Ops tf, labels, predictions); * // produces 0.240f * * @@ -40,64 +41,61 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
    - *    Operand<TFloat32> result = msle.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = msle.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.120f
      * 
    * *

    Using SUM reduction type: * *

    - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = msle.call(labels, predictions);
    + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.SUM);
    + *    Operand<TFloat32> result = msle.call(tf, labels, predictions);
      *    // produces 0.480f
      * 
    * *

    Using NONE reduction type: * *

    - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = msle.call(labels, predictions);
    + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.NONE);
    + *    Operand<TFloat32> result = msle.call(tf, labels, predictions);
      *    // produces [0.240f, 0.240f]
      * 
    */ -public class MeanSquaredLogarithmicError extends Loss { +public class MeanSquaredLogarithmicError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredLogarithmicError(Ops tf) { - super(tf); + public MeanSquaredLogarithmicError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredLogarithmicError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredLogarithmicError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredLogarithmicError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredLogarithmicError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 78324acf8a5..a6eb29b7109 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * Poisson poissonLoss = new Poisson(tf); - * Operand<TFloat32> result = poissonLoss.call(labels, predictions); + * Operand<TFloat32> result = poissonLoss.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,74 +41,71 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
    - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.4f
      * 
    * *

    Using SUM reduction type: * *

    - *    Poisson poissonLoss = new Poisson(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
    + *    Poisson poissonLoss = new Poisson(Reduction.SUM);
    + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions);
      *    // produces 0.999f
      * 
    * *

    Using NONE reduction type: * *

    - *    Poisson poissonLoss = new Poisson(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
    + *    Poisson poissonLoss = new Poisson(Reduction.NONE);
    + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions);
      *    // produces [0.999f, 0f]
      * 
    */ -public class Poisson extends Loss { +public class Poisson extends AbstractLoss { /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Poisson(Ops tf) { - this(tf, null, Reduction.AUTO); + public Poisson() { + this(null, Reduction.AUTO); } /** - * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a Poisson AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Poisson(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public Poisson(String name) { + this(name, Reduction.AUTO); } /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Poisson(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public Poisson(Reduction reduction) { + this(null, reduction); } /** - * Creates a Poisson Loss + * Creates a Poisson AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Poisson(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Poisson(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.poisson(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.poisson(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index 87ea43c6c3a..e40ec6d6ebb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.losses; /** - * Type of Loss Reduction + * Type of AbstractLoss Reduction * *

    {@link #AUTO} indicates that the reduction option will be determined by the usage context. For * almost all cases this defaults to {@link #SUM_OVER_BATCH_SIZE}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index d04cc67d5d9..291a91894b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -43,7 +44,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf); - * Operand<TFloat32> result = sparseCCE.call(labels, predictions); + * Operand<TFloat32> result = sparseCCE.call(Ops tf, labels, predictions); * // produces 1.177f * * @@ -51,27 +52,27 @@ * *

      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
    - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions, sampleWeight);
    + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions, sampleWeight);
      *    // produces 0.814f
      * 
    * *

    Using SUM reduction type: * *

    - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
    + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.SUM);
    + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
      *    // produces 2.354f
      * 
    * *

    Using NONE reduction type: * *

    - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
    + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.NONE);
    + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
      *    // produces [0.0513f, 2.303f]
      * 
    */ -public class SparseCategoricalCrossentropy extends Loss { +public class SparseCategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final int AXIS_DEFAULT = -1; @@ -80,24 +81,23 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ - public SparseCategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of this loss function */ - public SparseCategoricalCrossentropy(Ops tf, String name) { - this(tf, name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -107,8 +107,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name) { * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. 
*/ - public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** @@ -119,32 +119,32 @@ public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { * @param name the name of this loss function * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} and fromLogits={@link * #FROM_LOGITS_DEFAULT}. 
* * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -155,8 +155,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduction) { - this(tf, null, fromLogits, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits, Reduction reduction) { + this(null, fromLogits, reduction, AXIS_DEFAULT); } /** @@ -170,8 +170,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduc * and axis=1 corresponds to data format 'Channels First'. 
*/ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, Reduction reduction, int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, Reduction reduction, int axis) { + super(name, reduction); this.fromLogits = fromLogits; this.axis = axis; } @@ -199,23 +199,24 @@ public SparseCategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.sparseCategoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.sparseCategoricalCrossentropy(tf, labels, lPredictions, fromLogits, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index dadbdb3b95e..c804b463984 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import 
org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * SquaredHinge squaredHinge = new SquaredHinge(tf); - * Operand<TFloat32> result = squaredHinge.call(labels, predictions); + * Operand<TFloat32> result = squaredHinge.call(Ops tf, labels, predictions); * // produces 1.86f * * @@ -45,7 +46,7 @@ * *
      *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
    - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions,
    + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions,
      *                                                  sampleWeight);
      *    // produces 0.73f
      * 
    @@ -53,50 +54,46 @@ *

    Using SUM reduction type: * *

    - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.SUM);
    - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
    + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.SUM);
    + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
      *    // produces 3.72f
      * 
    * *

    Using NONE reduction type: * *

    - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.NONE);
    - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
    + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.NONE);
    + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
      *    // produces [1.46f, 2.26f]
      * 
    */ -public class SquaredHinge extends Loss { +public class SquaredHinge extends AbstractLoss { /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public SquaredHinge(Ops tf) { - super(tf); + public SquaredHinge() { + super(); } /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public SquaredHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Squared Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public SquaredHinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -123,19 +120,17 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? 
(Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.squaredHinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java new file mode 100644 index 00000000000..9534f6fe3ad --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.Reduction; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +public abstract class AbstractLoss implements Loss { + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + + protected final Reduction reduction; + private final String name; + + /** + * Creates a AbstractLoss using {@link Class#getSimpleName()} as the name and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + */ + protected AbstractLoss() { + this(null, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss using a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + * + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. + */ + protected AbstractLoss(String name) { + this(name, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss + * + * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. + * @param reduction Type of Reduction to apply to the loss. + */ + protected AbstractLoss(String name, Reduction reduction) { + this.name = name == null ? getClass().getSimpleName() : name; + this.reduction = reduction; + } + + /** + * Calculates the loss + * + * @param tf the TensorFlow Ops + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the predictions and loss. 
+ * @return the loss + */ + public Operand call( + Ops tf, Operand labels, Operand predictions) { + return call(tf, labels, predictions, null); + } + + /** + * Gets the loss reduction + * + * @return the loss reduction + */ + public Reduction getReduction() { + return reduction; + } + + /** + * Gets the name for this loss + * + * @return the name for this loss + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index bc5047d5855..69cb2ee0dfe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -40,26 +40,26 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of - * recall and precision values. The area under the ROC-curve is therefore computed using the height - * of the recall values by the false positive rate, while the area under the PR-curve is the - * computed using the height of the precision values by the recall. + *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives }, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the AUC. To discretize + * the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and + * precision values. The area under the ROC-curve is therefore computed using the height of the + * recall values by the false positive rate, while the area under the PR-curve is the computed using + * the height of the precision values by the recall. * - *

    This value is ultimately returned as {@code auc}, an idempotent operation that computes - * the area under a discretized curve of precision versus recall values (computed using the + *

    This value is ultimately returned as {@code auc}, an idempotent operation that computes the + * area under a discretized curve of precision versus recall values (computed using the * aforementioned variables). The {@code numThresholds} variable controls the degree of * discretization with larger numbers of thresholds more closely approximating the true AUC. The - * quality of the approximation may vary dramatically depending on {@code numThresholds}. The - * {@code thresholds} parameter can be used to manually specify thresholds which split the - * predictions more evenly. + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The {@code + * thresholds} parameter can be used to manually specify thresholds which split the predictions more + * evenly. * - *

    For best results, {@code predictions} should be distributed approximately uniformly in - * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor - * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code - * majoring} can help quantify the error in the approximation by providing lower or upper - * bound estimate of the AUC. + *

    For best results, {@code predictions} should be distributed approximately uniformly in the + * range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if + * this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code majoring} can + * help quantify the error in the approximation by providing lower or upper bound estimate of the + * AUC. * *

    Usage:
    * @@ -155,8 +155,8 @@ public class AUC extends Metric { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -180,8 +180,8 @@ public AUC(Ops tf, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -206,8 +206,8 @@ public AUC(Ops tf, String name, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code null} + * for labelWeights. 
* * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -233,8 +233,8 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -259,8 +259,8 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -314,8 +314,8 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for - * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. 
* * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -372,8 +372,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -400,8 +400,8 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, - * and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -428,8 +428,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, - * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for - * labelWeights. + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. 
Values @@ -453,8 +452,8 @@ public AUC( /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} - * for labelWeights. + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} for + * labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -487,8 +486,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code - * false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code false} + * for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -513,8 +512,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -560,16 +559,16 @@ public AUC( * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. This method - * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) - * below and a (1 + {@link #EPSILON}) above. + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) below and + * a (1 + {@link #EPSILON}) above. 
* @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to {@code false} for multi-class data. * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code - * multiLabel} is true, the weights are applied to the individual label AUCs when they - * are averaged to produce the multi-label AUC. When it's false, they are used to weight the + * multiLabel} is true, the weights are applied to the individual label AUCs when they are + * averaged to produce the multi-label AUC. When it's false, they are used to weight the * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -684,8 +683,8 @@ private Map> build(Shape shape) { } // Create metric variables - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(variableShape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(variableShape), type); if (truePositives == null) { truePositives = tf.withName(getTruePositivesName()).variable(zero); initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); @@ -715,8 +714,8 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to {@code }. 
If {@link #multiLabel} or if {@link #labelWeights} {@code != null - * }, then Cx must be a single dimension. + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null }, + * then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to * {@code }. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 516d6c91ba6..b8ec681cbfc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -29,12 +29,10 @@ * Metric that calculates how often predictions equals labels. * *

    This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

    If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

    If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 0e41699e165..a03677efd43 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -26,12 +26,10 @@ * Metric that calculates how often predictions matches binary labels. * *

    This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

    If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

    If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index dece2d1cd50..0cd90325e32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,18 +27,17 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

    You can provide {@code logits} of classes as {@code predictions}, since argmax of - * {@code logits} and probabilities are same. + *

    You can provide {@code logits} of classes as {@code predictions}, since argmax of {@code + * logits} and probabilities are same. * - *

    This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the frequency with which {@code predictions} matches {@code labels}. - * This frequency is ultimately returned as categorical accuracy: an idempotent operation that - * simply divides total by count. + *

    This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as categorical accuracy: an idempotent operation that simply divides total by + * count. * - *

    {@code predictions} and {@code labels} should be passed in as vectors of - * probabilities, rather than as labels. If necessary, use {@link - * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand - * {@code labels} as a vector. + *

    {@code predictions} and {@code labels} should be passed in as vectors of probabilities, rather + * than as labels. If necessary, use {@link org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, + * Operand, OneHot.Options...)} to expand {@code labels} as a vector. * *

    If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 58aa51f664c..4a32981aeeb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -29,8 +29,7 @@ * *

    This is the crossentropy metric class to be used when there are multiple label classes (2 or * more). The labels should be given as a one_hot representation. eg., When labels values are {@code - * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * }. + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] }. * * @param The data type for the metric result */ @@ -52,9 +51,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +72,12 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. 
{@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} - * corresponds to data format {@code channels_last}, and {@code - * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code - * channels_first}. + * corresponds to data format {@code channels_last}, and {@code axis={@link + * Losses#CHANNELS_FIRST}} corresponds to data format {@code channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 3db7fffc2e9..9f957ee6c17 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false negatives. * - *

    If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false negatives. + *

    If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false negatives. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalseNegatives(Ops tf, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 551529b6179..a3d585dea0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false positives. * - *

    If {@code sampleWeights} is given, calculates the sum of the weights of false positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false positives. + *

    If {@code sampleWeights} is given, calculates the sum of the weights of false positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false positives. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalsePositives(Ops tf, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 22baab3d6cb..04f4deb81cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -93,11 +93,15 @@ private void init() { Shape variableShape = Shape.of(numClasses, numClasses); if (totalConfusionMatrix == null) { - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); totalConfusionMatrix = - getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + getTF() + .withName(totalCMName) + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); initializer = - getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + getTF() + .assign( + totalConfusionMatrix, zeros.call(getTF(), getTF().constant(variableShape), type)); } } @@ -124,8 +128,8 @@ public Assign getInitializer() { * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. 
* @return the Operands that updates totalConfusionMatrix variable - * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, - * and if the predictions size is not equal to the labels size + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank {@code !=} + * labels rank, and if the predictions size is not equal to the labels size */ @Override public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index acf28f5b2cc..8d92b97ec5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -28,13 +28,12 @@ /** * Computes the mean relative error by normalizing with the given values. * - *

    This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is - * ultimately returned as mean relative error: an idempotent operation that simply divides total by - * count. + *

    This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the mean relative error. This is weighted by {@code sampleWeight}, and it is ultimately + * returned as mean relative error: an idempotent operation that simply divides total by count. * - *

    If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} - * of 0 to mask values. + *

    If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} of 0 + * to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index d88d7a4c1b4..583d9b2dde7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -85,8 +85,8 @@ public MeanTensor(Ops tf, String name, long seed, Class type) { private boolean init(Shape shape) { if (!initialized) { this.shape = shape; - Zeros zeros = new Zeros<>(getTF()); - Operand zero = zeros.call(getTF().constant(shape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(getTF(), getTF().constant(shape), type); if (total == null) { total = getTF().withName(totalName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 3812e799b75..f81b32e8d76 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -36,22 +36,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

    The metric creates two local variables, {@code truePositives} and {@code falsePositives - * } that are used to compute the precision. This value is ultimately returned as precision, - * an idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falsePositives}. + *

    The metric creates two local variables, {@code truePositives} and {@code falsePositives } that + * are used to compute the precision. This value is ultimately returned as precision, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falsePositives}. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

    If {@code topK} is set, the metric calculates precision as how often on average a class - * among the top-k classes with the highest predicted values of a batch entry is correct and can be - * found in the label for that entry. + *

    If {@code topK} is set, the metric calculates precision as how often on average a class among + * the top-k classes with the highest predicted values of a batch entry is correct and can be found + * in the label for that entry. * *

    If {@code classId} is specified, the metric calculates precision by considering only the - * entries in the batch for which {@code classId} is above the {@code thresholds} and/or - * in the top-k highest predictions, and computing the fraction of them for which {@code classId - * } is indeed a correct label. + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or in the + * top-k highest predictions, and computing the fraction of them for which {@code classId } is + * indeed a correct label. * * @param The data type for the metric result */ @@ -103,10 +103,9 @@ public Precision(Ops tf, String name, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -138,10 +137,9 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). 
One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -172,10 +170,9 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -216,10 +213,9 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. 
+ * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -280,17 +276,15 @@ public Precision( /** Initializes the variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { this.truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); } if (this.falsePositives == null) { - this.falsePositives = - tf.withName(falsePositivesName) - .variable(zero); + this.falsePositives = tf.withName(falsePositivesName).variable(zero); initializers.add(tf.assign(falsePositives, zero)); } } @@ -340,11 +334,12 @@ public List updateStateList( public Operand result() { Ops tf = getTF(); Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - return thresholds.length == 1 - ? tf.reshape(tf.slice( - result, - tf.expandDims(tf.constant(0), tf.constant(0)), - tf.expandDims(tf.constant(1), tf.constant(0))), + return thresholds.length == 1 + ? 
tf.reshape( + tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))), tf.constant(Shape.scalar())) : result; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 5f5f9b47a10..0bb49378f5b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -29,8 +29,8 @@ * falseNegatives that are used to compute the precision at the given recall. The threshold for the * given recall value is computed and used to evaluate the corresponding precision. * - *

    If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

    If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ @@ -115,8 +115,7 @@ public PrecisionAtRecall( public Operand result() { Ops tf = getTF(); - Operand div = - tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand div = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 3886ec050b0..2780add994f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -36,20 +36,20 @@ /** * Computes the recall of the predictions with respect to the labels. * - *

    This metric creates two local variables, {@code truePositives} and {@code falseNegatives - * }, that are used to compute the recall. This value is ultimately returned as recall, an - * idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falseNegatives}. + *

    This metric creates two local variables, {@code truePositives} and {@code falseNegatives }, + * that are used to compute the recall. This value is ultimately returned as recall, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falseNegatives}. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

    If {@code topK} is set, the metric calculates recall as how often on average a class - * among the labels of a batch entry is in the top-k predictions. + *

    If {@code topK} is set, the metric calculates recall as how often on average a class among the + * labels of a batch entry is in the top-k predictions. * - *

    If {@code classId} is specified, the metric calculates recall by considering only the - * entries in the batch for which {@code classId} is in the label, and computing the fraction - * of them for which {@code classId} is above the threshold and/or in the top-k predictions. + *

    If {@code classId} is specified, the metric calculates recall by considering only the entries + * in the batch for which {@code classId} is in the label, and computing the fraction of them for + * which {@code classId} is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ @@ -305,8 +305,8 @@ public Recall( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { truePositives = tf.withName(truePositivesName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index a3fc2f77b7f..e54def48fce 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -34,8 +34,8 @@ * falseNegatives that are used to compute the recall at the given precision. The threshold for the * given precision value is computed and used to evaluate the corresponding recall. * - *

    If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

    If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 3886428425b..0d140eb96b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,8 +27,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between {@code labels} and {@code predictions} - * . + * Computes root mean squared error metric between {@code labels} and {@code predictions} . * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 29c0504b823..23a529ae1bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -25,19 +25,19 @@ /** * Computes best sensitivity where sensitivity is >= specified value. * - *

    {@code Sensitivity} measures the proportion of actual positives that are correctly - * identified as such {@code (tp / (tp + fn))}. + *

    {@code Sensitivity} measures the proportion of actual positives that are correctly identified + * as such {@code (tp / (tp + fn))}. * - *

    {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

    {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * sensitivity at the given specificity. The threshold for the given specificity value is computed - * and used to evaluate the corresponding sensitivity. + *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives }, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the sensitivity at the + * given specificity. The threshold for the given specificity value is computed and used to evaluate + * the corresponding sensitivity. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 5294f798044..1d017ddf8fb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -35,9 +35,9 @@ * probabilities are same. * *

    This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides - * `total` by `count`. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as `sparse categorical accuracy`: an idempotent operation that simply divides `total` by + * `count`. * *

    If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 2cb7e54eba0..95d46c8fd06 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -24,19 +24,19 @@ /** * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} - * measures the proportion of actual positives that are correctly identified as such {@code - * (tp / (tp + fn))}. + * measures the proportion of actual positives that are correctly identified as such {@code (tp / + * (tp + fn))}. * - *

    {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

    {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * specificity at the given sensitivity. The threshold for the given sensitivity value is computed - * and used to evaluate the corresponding specificity. + *

    This metric creates four local variables, {@code truePositives}, {@code trueNegatives }, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the specificity at the + * given sensitivity. The threshold for the given sensitivity value is computed and used to evaluate + * the corresponding specificity. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index 637ca6cdd05..bcb1d7b9a36 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -21,11 +21,11 @@ /** * Computes the (weighted) sum of the given values. * - *

    For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the - * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} + *

    For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the weights were + * specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} * - *

    This metric creates one variable, {@code total}, that is used to compute the sum of - * values. This is ultimately returned as sum. + *

    This metric creates one variable, {@code total}, that is used to compute the sum of values. + * This is ultimately returned as sum. * *

    If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index 0146552433f..b6e50c3295a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -34,8 +34,8 @@ public class TopKCategoricalAccuracy extends MeanMetricWrappe private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of - * top elements to look at for computing accuracy. + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of top + * elements to look at for computing accuracy. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index 5c65f8c469f..fd6b95df6d2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true negatives. * - *

    If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true negatives. + *

    If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true negatives. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TrueNegatives(Ops tf, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index f0dd8c42de5..90fe9142014 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true positives. * - *

    If {@code sampleWeights} is given, calculates the sum of the weights of true positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true positives. + *

    If {@code sampleWeights} is given, calculates the sum of the weights of true positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true positives. * - *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

    If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TruePositives(Ops tf, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 88597cf85ec..b031d80d0ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -67,10 +67,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -91,10 +90,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -118,12 +116,13 @@ public ConfusionMatrixConditionCount( private void init() { Shape variableShape = Shape.of(this.thresholds.length); - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); accumulator = getTF() .withName(getAccumulatorName()) - .variable(zeros.call(getTF().constant(variableShape), type)); - initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); + initializer = + getTF().assign(accumulator, zeros.call(getTF(), getTF().constant(variableShape), type)); } /** @@ -189,7 +188,10 @@ public float[] getThresholds() { return this.thresholds; } - /** @return the accumulatorName */ + /** + * Gets the accumulatorName + * @return the accumulatorName + */ public String getAccumulatorName() { return accumulatorName; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index f89047e457d..76c21aebefc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * Interface for Metrics that wrap Loss functions. + * Interface for Metrics that wrap AbstractLoss functions. * * @param The data type of the predictions. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 37bdd5849ae..ec103197709 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,9 +29,9 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

    The loss function calculates the loss between the {@code labels} and {@code predictions - * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the - * loss over many iterations or epochs + *

    The loss function calculates the loss between the {@code labels} and {@code predictions } then + * passes this loss to the {@link Mean} metric to calculate the weighted mean of the loss over many + * iterations or epochs * * @param The data type for the metric result */ @@ -63,7 +63,7 @@ public LossMetric getLoss() { } /** - * Sets the Loss function for this wrapper. + * Sets the AbstractLoss function for this wrapper. * * @param loss the loss function. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 40336233d21..51b8836ec83 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -59,8 +59,7 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values - * } + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values } * *

    In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -69,8 +68,8 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} - * can be broadcast to {@code values} + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} can be + * broadcast to {@code values} * @param the type of Operand * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an * incorrect shape that prohibit broadcasting to {@code values} @@ -114,10 +113,7 @@ public static Op assertBroadcastable( throw new NotBroadcastableException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCAST_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -307,24 +303,24 @@ public static List assertShapes( *

    For estimation of these metrics over a stream of data, the function creates an `update_op` * operation that updates the given variables. * - *

    {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code - * sampleWeight} is then broadcast to the shape of {@code predictions}. + *

    {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are aligned by {@link + * LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code sampleWeight} is then + * broadcast to the shape of {@code predictions}. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then - * the variable shapes are (T). + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then the + * variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when - * topK is set + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when topK + * is set * @param topK optional, indicates that only the top k predictions should be considered. Applied * before possibly slicing by {@code classIndex}. * @param classIndex optional, limits the prediction and labels to the specified class. This is an @@ -338,14 +334,14 @@ public static List assertShapes( * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. 
when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex - * } is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. - * Must be null if {@code multiLabel}. + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex } + * is provided, then the final dimension of Dx is 1. These weights will be broadcast across + * the 0th dimension (the examples dimension) of {@code predictions}. May be null. Must be + * null if {@code multiLabel}. * @param the data type for the variables - * @throws IllegalArgumentException If {@code predictions} and {@code labels} have - * mismatched shapes, or if {@code sampleWeight} is not null and its shape - * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. + * @throws IllegalArgumentException If {@code predictions} and {@code labels} have mismatched + * shapes, or if {@code sampleWeight} is not null and its shape doesn't match {@code + * predictions}, or if {@code multiLabel && labelWeights != null}.. * @return an op to update the given confusion matrix variables. 
*/ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -439,11 +435,13 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), if (classIndex != null) { // Slice to new shapes (N, Dx) - tLabels = tf.squeeze(tf.gather(tLabels, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tLabels = + tf.squeeze( + tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); - tPredictions = tf.squeeze(tf.gather(tPredictions, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tPredictions = + tf.squeeze( + tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); @@ -693,8 +691,7 @@ private static Operand filterTopK(Ops tf, Operand x, i // alias for mean /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -706,8 +703,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -725,10 +722,9 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the type of the operand * @return the mean of elements of {@code x}. */ @@ -742,10 +738,9 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the data type of the Operand * @return the mean of elements of {@code x}. */ @@ -783,12 +778,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( *

    For example: * *

    {@code
    -   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    -   *          [[0 0 0 0 0]
    -   *           [0 0 1 0 0]
    -   *           [0 0 1 0 0]
    -   *           [0 0 0 0 0]
    -   *           [0 0 0 0 1]]
    +   * confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    +   *      [[0 0 0 0 0]
    +   *       [0 0 1 0 0]
    +   *       [0 0 1 0 0]
    +   *       [0 0 0 0 0]
    +   *       [0 0 0 0 1]]
        * }
    * * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 @@ -802,12 +797,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands - * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} - * representing the confusion matrix, where {@code n} is the number of possible labels in - * the classification task. - * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do - * not have compatible shapes, or if {@code weights} is not{@code null} and its - * shape is not compatible with {@code predictions}. + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} representing the + * confusion matrix, where {@code n} is the number of possible labels in the classification + * task. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do not have + * compatible shapes, or if {@code weights} is not{@code null} and its shape is not compatible + * with {@code predictions}. */ // TODO should this be moved to FramnworkOps under math. 
public static Operand confusionMatrix( @@ -883,8 +878,7 @@ public static Operand confusionMatrix( } /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -895,8 +889,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -913,10 +907,9 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { @@ -929,10 +922,9 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 60a6c1ea3df..e47ea4ea8e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -87,9 +87,9 @@ protected SensitivitySpecificityBase( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); + Zeros zeros = new Zeros<>(); Shape varShape = Shape.of(numThresholds); - Operand zero = zeros.call(tf.constant(varShape), type); + Operand zero = zeros.call(tf, tf.constant(varShape), type); if (this.getTruePositives() == null) { @@ -228,8 +228,6 @@ public int getNumThresholds() { return numThresholds; } - - /** * Gets the thresholds * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 68157632557..0553b1edac7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ 
-26,8 +26,8 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with - * {@code aMinusB} set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with {@code + * aMinusB} set to true. * *

    All but the last dimension of {@code a} and {@code b} must match * @@ -35,8 +35,8 @@ public class SetsOps { * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference(Ops tf, Operand a, Operand b) { @@ -53,8 +53,8 @@ public static Operand difference(Ops tf, Operand a, Op * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference( @@ -69,8 +69,8 @@ public static Operand difference( * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. 
*/ public static Operand union(Ops tf, Operand a, Operand b) { @@ -84,8 +84,8 @@ public static Operand union(Ops tf, Operand a, Operand * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand intersection(Ops tf, Operand a, Operand b) { @@ -100,8 +100,8 @@ public static Operand intersection(Ops tf, Operand a, * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the same. Elements along the last dimension contain the results of the set * operation. */ public static Operand setOperation( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java index d28185ae041..7c3fda07ea9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -21,35 +21,72 @@ import java.util.Arrays; import java.util.List; +/** + * A class that represents a Symbolic shape. + * + *

    A Symbolic shape uses symbols to identify the relationship of the shape of an operand to + * underlying values that are not known until compute time. For example, "N" represents the number of + * examples, while "L" represents the number of labels. When the values later become known, the + * shape of the operand must conform to these symbolic values. + * + * @param The data type for the Operand. + */ public class SymbolicShape { private Operand operand; private List symbols = new ArrayList<>(); + /** + * Creates a SymbolicShape + * + * @param operand the Operand that needs to conform to the shape + * @param symbols the symbolic value for each dimension of the shape. + */ public SymbolicShape(Operand operand, String... symbols) { this.operand = operand; this.symbols.addAll(Arrays.asList(symbols)); } - /** @return the operand */ + /** + * Gets the operand + * + * @return the operand + */ public Operand getOperand() { return operand; } - /** @param operand the operand to set */ + /** + * Sets the operand + * + * @param operand the operand to set + */ public void setOperand(Operand operand) { this.operand = operand; } - /** @return the symbols */ + /** + * Gets the symbols associated with each dimension of the shape + * + * @return the symbols associated with each dimension of the shape + */ public List getSymbols() { return symbols; } - /** @param symbols the symbols to set */ + /** + * Sets the symbols associated with each dimension of the shape + * + * @param symbols the symbols associated with each dimension of the shape + */ public void setSymbols(List symbols) { this.symbols = symbols; } + /** + * Gets the rank associated with this Symbolic Shape + * + * @return the rank associated with this Symbolic Shape + */ public int rank() { return this.symbols.size(); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 6583465da2e..18b11700380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -32,8 +32,8 @@ /** * Weight broadcasting operations. * - *

    In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we support limited weight broadcasting. This file includes - * operations for those broadcasting rules. + *

    In {@link org.tensorflow.framework.losses} and {@link org.tensorflow.framework.metrics}, we + * support limited weight broadcasting. This file includes operations for those broadcasting rules. */ public class WeightsBroadcastOps { @@ -46,10 +46,11 @@ public class WeightsBroadcastOps { * @param tf the TensorFlow Ops * @param weights the weights Operand * @param values Operand of values to which weights are applied. - * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has incorrect shape. {@link NoOp} if - * static checks determine {@code weights} has correct shape. + * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has + * incorrect shape. {@link NoOp} if static checks determine {@code weights} has correct shape. * @param the type of weights and values - * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape. + * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect + * shape. */ public static Op assertBroadcastable( Ops tf, Operand weights, Operand values) { @@ -81,14 +82,12 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (weightsShapeStatic.size(i) != 1 && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + if (weightsShapeStatic.size(i) != 1 + && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %s. 
values.shape=%s weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCASTABLE_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -105,12 +104,12 @@ public static Op assertBroadcastable( tf.constant("values.shape="), valuesShape, tf.constant("isScalar="), - isScalar); + isScalar); Operand isValidShape = tf.select( - isScalar, - isScalar, + isScalar, + isScalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); return tf.assertThat(isValidShape, data); @@ -140,7 +139,8 @@ private static Operand hasValidNonscalarShape( } /** - * Checks that each dimension of the two shapes are the same size, or that the weight dimension size is 1. + * Checks that each dimension of the two shapes are the same size, or that the weight dimension + * size is 1. * * @param tf the TensorFlow Ops * @param weightsShape the shape of the weights @@ -152,7 +152,8 @@ private static Operand hasValidDims( tf = tf.withSubScope("hasInvalidDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); - Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); @@ -164,8 +165,7 @@ private static Operand hasValidDims( * Broadcast {@code weights} to the same shape as {@code values}. * *

    This returns a version of {@code weights} following the same broadcast rules as {@code - * mul(weights, - * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * mul(weights, values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} * When computing a weighted average, use this function to broadcast {@code weights} before * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java new file mode 100644 index 00000000000..25535292db3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.framework.losses.impl.AbstractLoss; + +/** + * Base class for Regularizers + * + *

    Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + */ +public abstract class AbstractRegularizer implements Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final String name; + + /** Creates an AbstractRegularizer, using {@link Class#getSimpleName()} for the name */ + protected AbstractRegularizer() { + this(null); + } + /** + * Creates an AbstractRegularizer + * + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. + */ + protected AbstractRegularizer(String name) { + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this AbstractRegularizer as an AbstractLoss. This is a convenience to regularize a + * loss. Only sampleWeights are applied to the regularizer. + * + * @return this AbstractRegularizer as an AbstractLoss + */ + public AbstractLoss asLoss() { + return new RegularizerLoss(this); + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 7c8c2a1360a..4b7aa1af620 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -14,8 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) * Regression, regularization penalty. 
@@ -24,24 +22,43 @@ */ public class L1 extends L1L2 { + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L1() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L1 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L1(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L1 regularization penalty and a name based on the class + * name. + * + * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L1(float l1) { + this(null, l1); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l1 the L1 regularization penalty * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf, float l1) { - super(tf, l1, 0f); + public L1(String name, float l1) { + super(name, l1, 0f); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 29e411f9897..6dfaf3f0d47 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -19,6 +19,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A regularizer that applies both L1 and L2 regularization penalties. * @@ -29,33 +31,39 @@ *

    The L2 regularization penalty is computed as * *

    loss = l2 * reduceSum(square(x))
    - * */ -public class L1L2 extends Regularizer { +public class L1L2 extends AbstractRegularizer { private final float l1; private final float l2; + /** Creates an L1L2 regularizer with l1 and l2 penalties of {@link #DEFAULT_REGULARIZATION_PENALTY} */ + public L1L2() { + this(DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + /** - * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty + * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor; 0 disables the L1 penalty. + * @param l2 L2 regularization factor; 0 disables the L2 penalty. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * or {@link Float#isInfinite} */ - public L1L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + public L1L2(float l1, float l2) { + this(null, l1, l2); } /** * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. 
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, float l1, float l2) { - super(tf); + public L1L2(String name, float l1, float l2) { + super(name); if (Float.isNaN(l1) || Float.isInfinite(l1)) { throw new IllegalArgumentException( String.format( @@ -73,25 +81,23 @@ public L1L2(Ops tf, float l1, float l2) { this.l2 = l2; } - /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand input) { if (this.getL1() == 0f && this.getL2() == 0f) { - return tf.dtypes.cast(tf.constant(0), input.type()); + return cast(tf, tf.constant(0), input.type()); } - Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + Operand regularization = cast(tf, tf.constant(0), input.type()); if (this.getL1() != 0.f) { - Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand l1Op = cast(tf, tf.constant(this.getL1()), input.type()); Operand abs = tf.math.abs(input); Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); } if (this.getL2() != 0.f) { - Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand l2Op = cast(tf, tf.constant(this.getL2()), input.type()); Operand sqr = tf.math.square(input); Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 7b8f5b28a70..9092b80b08f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -14,8 +14,6 @@ 
=======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * @@ -23,24 +21,43 @@ */ public class L2 extends L1L2 { + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L2() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L2 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L2(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L2 regularization penalty and a name based on the class + * name. + * + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ - public L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L2(float l2) { + this(null, l2); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l2 the L2 regularization penalty * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. 
*/ - public L2(Ops tf, float l2) { - super(tf, 0f, l2); + public L2(String name, float l2) { + super(name, 0f, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 5d9ff0e3e10..085f06e115c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,77 +15,18 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Base class for Regularizers - * - *

    Regularizers allow you to apply penalties on layer parameters or layer activity during - * optimization. These penalties are summed into the loss function that the network optimizes. - */ -public abstract class Regularizer { - - public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; - - private final Ops tf; - private final String name; - - /** - * Creates a Regularizer, using {@link Class#getSimpleName()} for the name - * - * @param tf the TensorFlow ops. - */ - protected Regularizer(Ops tf) { - this(tf, null); - } - /** - * Creates a Regularizer - * - * @param tf the TensorFlow ops. - * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the - * name. - */ - protected Regularizer(Ops tf, String name) { - this.tf = tf; - this.name = name == null ? this.getClass().getSimpleName() : name; - } - - /** - * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only - * sampleWeights are applied to the regularizer. - * - * @return this Regularizer as a Loss - */ - public Loss asLoss() { - return new RegularizerLoss(this.tf, this); - } +public interface Regularizer { /** * Computes a regularization penalty from an input. 
* + * @param tf the TensorFlow Ops * @param input the weighted input * @return the result of computing the regularization penalty * @param the data type of the input and result */ - public abstract Operand call(Operand input); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the name for this regularizer - * - * @return the name for this regularizer - */ - public String getName() { - return name; - } + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index 582cd038f8f..11c7ee492e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -15,50 +15,49 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A Regularizer call wrapped as a Loss instance + * A AbstractRegularizer call wrapped as a AbstractLoss instance * *

    This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. */ -class RegularizerLoss extends Loss { +class RegularizerLoss extends AbstractLoss { - private final Regularizer regularizer; + private final AbstractRegularizer regularizer; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates a AbstractLoss using {@link Class#getSimpleName()} as the name and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, Regularizer regularizer) { - this(tf, null, regularizer); + public RegularizerLoss(AbstractRegularizer regularizer) { + this(null, regularizer); } /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a AbstractLoss using a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. 
* @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { - super(tf, name); + public RegularizerLoss(String name, AbstractRegularizer regularizer) { + super(name); this.regularizer = regularizer; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } - return regularizer.call(sampleWeights); + return regularizer.call(tf, sampleWeights); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java index 914b94dfada..9f3fa75e95d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -14,36 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - -/** @author Jim Clarke */ public class ELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of ELU call method */ @Test public void testCallFloat() { @@ -52,8 +33,8 @@ public void testCallFloat() 
{ for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -66,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -80,8 +61,8 @@ public void testAlpha() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf, 2.0f); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(2.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java index 1157c582168..f82c19987d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static 
org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class ExponentialTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ExponentialTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Exponential call method. */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -78,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java index 35f57c47f66..0e32201c3e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import 
org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class HardSigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public HardSigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of HardSigmoid call method. */ @Test public void testCallFloat() { @@ -51,8 +33,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -65,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java index 7974035c680..817940688e8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class LinearTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LinearTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Linear call method. */ @Test public void testCallInt() { @@ -48,8 +34,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -62,8 +48,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -76,8 +62,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index a0aa2c4b453..94f803d6b1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -14,30 +14,20 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; /** @author Jim Clarke */ public class ReLUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ReLUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of ReLU call method */ @Test public void testCallFloat() { @@ -46,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -60,8 +50,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = 
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -74,8 +64,8 @@ public void testCallLong() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,9 +78,9 @@ public void testCallFloat16() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU<>(); Operand result = - instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); + instance.call(tf, tf.dtypes.cast(tf.constant(input), TFloat16.class)); session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); } } @@ -103,8 +93,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -112,12 +102,12 @@ public void testCallDouble() { @Test public void testAlpha() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-5. 
, -2.5, 0., 5., 10.}; + double[] expected = {-5., -2.5, 0., 5., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -129,8 +119,8 @@ public void testMaxValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -138,12 +128,12 @@ public void testMaxValue() { @Test public void testThreshold() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-0., -0., 0., 0., 10.}; + double[] expected = {-0., -0., 0., 0., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index 8bad6f1f066..ef4644df18e 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of SELU call method */ @Test public void testCallFloat() { @@ -53,8 +35,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +53,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java index 9dca622c3ec..0c59eeaba6e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -14,34 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Sigmoid call method */ @Test public void testCallFloat() { @@ -59,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -77,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java index 05ec3a4f716..aeb971905a2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -14,35 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SoftmaxTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftmaxTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Softmax method, of class Activations. 
*/ @Test public void testSoftmaxOpsOperandFloat() { @@ -54,8 +37,8 @@ public void testSoftmaxOpsOperandFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -71,8 +54,8 @@ public void testSoftmaxOpsOperandDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,8 +71,8 @@ public void testSoftmaxOpsOperandDoubleNegative() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -99,14 +82,14 @@ public void testSoftmaxOpsOperandDoubleNegative() { public void testSoftmax1D() { double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; double[] expected = { - 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, - 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 + 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, + 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = 
instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -116,14 +99,14 @@ public void testSoftmax1D() { public void testSoftmax3D() { double[][][] input = {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}; double[][][] expected = { - {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, - {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} + {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, + {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java index a17f2650d62..e896807d9f7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftplusTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftplusTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static 
void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softplus call method */ @Test public void testCallFloat() { @@ -50,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -68,8 +54,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java index 43591ab4761..2f9a17caf59 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftsignTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftsignTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - 
public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softsign call method */ @Test public void testCallFloat() { @@ -48,8 +34,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java index 7576789320b..8dabfaf379a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SwishTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SwishTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Swish call method */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -83,8 +65,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java index 5162e141c44..3988ec55bb3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -25,20 +25,6 @@ public class TanhTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, 
TestSession.Mode.GRAPH}; - public TanhTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Tanh call method. */ @Test public void testCallFloat() { @@ -52,8 +38,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index 1f80388e88f..259d6a963b5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -35,8 +35,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MaxNorm instance = new MaxNorm(tf, testValues[i.get()]); - Operand result = instance.call(weights); + MaxNorm instance = new MaxNorm(testValues[i.get()]); + Operand result = instance.call(tf, weights); session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } } @@ -47,13 +47,13 @@ public void testCall1() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm(tf, 2.0); + MaxNorm instance = new MaxNorm(2.0); Operand weights = tf.constant( new float[][] { {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, }); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float[] expected = { 0, 1, 2, 1.1547005f, 0, 0, 0, 1.1547005f, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8c2c3a54ff9..8b4c4007096 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -39,8 +39,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); - Operand result = instance.call(weights); + MinMaxNorm instance = new MinMaxNorm(testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(tf, weights); if (tfMode == TestSession.Mode.EAGER) evaluate(session, result.asTensor(), testValues[i.get()]); else diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java index 6a6fdc13536..1a24c188860 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -17,8 +17,8 @@ public void testTFloat32() { Ops tf = session.getTF(); float[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance 
= new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); float[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } @@ -31,8 +31,8 @@ public void testTFloat64() { Ops tf = session.getTF(); final double[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance = new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); double[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java index 6437ebcd760..9c784b7f31e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -28,8 +28,8 @@ public void testTFloat32() { }; Operand weights = tf.constant(array); - UnitNorm instance = new UnitNorm(tf, 1); - Operand result = instance.call(weights); + UnitNorm instance = new UnitNorm(1); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } @@ -50,9 +50,9 @@ public void testCallTFloat64() { {{0.72920675, 0.40984813, 0.55712338}, {0.68429305, 0.91215323, 0.83042956}}, {{0.97694125, 0.99972269, 0.13576831}, {0.21350717, 0.02353181, 0.99074035}} }; - UnitNorm instance = new UnitNorm(tf, 1); + UnitNorm instance = new UnitNorm(1); Operand weights = tf.constant(array); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 4e81e0620e6..9291e5f83ef 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class ConstantTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ConstantTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallUInt() { @@ -51,8 +43,9 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -67,8 +60,9 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -83,8 +77,9 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xffL); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Constant instance = new Constant<>(0xffL); + + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -97,8 +92,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 12.F); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Constant instance = new Constant<>(12.F); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -112,8 +108,9 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand 
operand = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -129,8 +126,9 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.class); + Constant instance = new Constant<>(22); + + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -145,8 +143,9 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Boolean[] expected = {true, true, true, true}; - Constant instance = new Constant<>(tf, true); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Constant instance = new Constant<>(true); + + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -158,9 +157,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index e9769806928..166011c3b64 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -14,7 +14,7 @@ 
=======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class GlorotTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public GlorotTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Glorot. */ @Test public void testCallNormalFloat() { @@ -51,9 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,8 +54,9 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,8 +69,9 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new 
Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -97,8 +85,9 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -109,9 +98,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -122,9 +112,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), 
TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -135,10 +126,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = - new Glorot<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 8953fa3005e..7b183358f85 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class HeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; int counter; - public HeTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class He. 
*/ @Test public void testCallNormalFloat() { @@ -51,8 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +53,9 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +68,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +84,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 
+97,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +111,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +125,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index 6eee5473937..3f5c6cdb363 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -14,37 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Identity initializer */ public class IdentityTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public IdentityTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallFloat() { @@ -64,8 +46,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -90,8 +72,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -103,9 +85,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Identity instance = new Identity<>(tf, 2.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 336850a5549..8858bac13dd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package 
org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class LeCunTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LeCunTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class LeCun. */ @Test public void testCallNormalFloat() { @@ -51,8 +37,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +52,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +66,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, 
SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +81,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +93,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +106,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, 
tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +119,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index 053ba5dd7ff..4872ce7ad8e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class OnesTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OnesTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Ones. */ @Test public void testCallUInt() { @@ -51,8 +43,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -65,8 +57,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -79,8 +71,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -93,8 +85,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); 
session.evaluate(expected, operand); } } @@ -108,8 +100,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -125,8 +117,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - instance.call(tf.constant(shape), TString.class); + Ones instance = new Ones<>(); + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -140,8 +132,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -153,9 +145,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.ones(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + 
Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index 22b89d9177c..c933e669dfd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -14,17 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Orthogonal initializer */ public class OrthogonalTest { @@ -33,20 +29,6 @@ public class OrthogonalTest { private static final double GAIN_VALUE = 1.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OrthogonalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Orthogonal. 
*/ @Test public void testCallFloat() { @@ -156,8 +138,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -271,8 +253,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -284,9 +266,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 3b2b3bdb243..dada058af42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class RandomNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomNormal. */ @Test public void testCalltestSoftmaxFloat() { @@ -52,9 +38,8 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } 
} @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index 23e26083a9b..1a1b3f755b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -31,20 +31,6 @@ public class RandomUniformTest { private static final double MAX_VALUE = 10.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomUniformTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomUniform. 
*/ @Test public void testCallInt() { @@ -53,9 +39,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -84,9 +68,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -98,10 +81,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + 
Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 96bf915e199..6ea19fde349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class TruncatedNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public TruncatedNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class TruncatedNormal. 
*/ @Test public void testCallFloat() { @@ -52,9 +38,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 159affb07e2..56aa95ecf73 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -28,20 +28,6 @@ public class VarianceScalingTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public VarianceScalingTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class VarianceScaling. */ @Test public void testCallFloat1FanInTruncatedNormal() { @@ -52,12 +38,11 @@ public void testCallFloat1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -73,12 +58,11 @@ public void testCallDouble1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -93,12 +77,8 @@ public void testCallFloat1FanInNormal() { Shape shape = 
Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -114,12 +94,8 @@ public void testCalltestSoftmaxDouble1FanInNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -134,8 +110,8 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -151,8 +127,8 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -166,9 +142,9 @@ public void 
testReproducible1() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -182,13 +158,9 @@ public void testReproducible2() { VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -202,13 +174,12 @@ public void testReproducible3() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_OUT, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -222,9 +193,9 @@ public void testReproducible4() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - 
Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java index 21bad6ff360..772baee1b61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java @@ -14,32 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; /** Test the Zeros initializer */ public class ZerosTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ZerosTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Zeros. 
*/ @Test public void testCallUInt() { @@ -48,8 +40,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -62,8 +54,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -76,8 +68,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -90,8 +82,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -105,8 +97,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand 
operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -119,8 +111,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TString.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TString.class); session.evaluateString(operand, String::isEmpty); } } @@ -134,8 +126,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -147,9 +139,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.zeros(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index d2128b80839..0b662414e8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -32,11 +32,12 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -48,9 +49,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new BinaryCrossentropy(tf, true); + instance = new BinaryCrossentropy(true); - loss = instance.call(yTrue, logits); + loss = instance.call(tf, yTrue, logits); testSession.evaluate(expected, loss); } } @@ -67,7 +68,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; float[] predArray = {2f, 1f, -1f, 0f}; Operand yTrue = @@ -75,7 +77,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -87,12 +89,13 @@ public void testUnweighted() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 3.83331f; testSession.evaluate(expected, loss); @@ -105,8 +108,9 @@ public void testUnweighted() { Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits); expected = 33.33333f; testSession.evaluate(expected, loss); } @@ -118,13 +122,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 8.816612f; testSession.evaluate(expected, loss); @@ -137,8 +142,9 @@ public void testScalarWeighted() { Operand yTrue1 = 
tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight); expected = 76.66667f; testSession.evaluate(expected, loss); } @@ -149,7 +155,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; float[] sampleWeightArray = {1.2f, 3.4f}; @@ -157,7 +164,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.59997f; testSession.evaluate(expected, loss); @@ -172,8 +179,9 @@ public void testSampleWeighted() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight1 = tf.constant(sampleWeightArray1); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight1); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight1); expected = 100f; testSession.evaluate(expected, loss); } @@ -196,8 +204,9 @@ public void testNoReduction() { tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); BinaryCrossentropy instance = new BinaryCrossentropy( - tf, true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); - Operand loss = 
instance.call(yTrue, logits); + true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.f, 66.666664f}; testSession.evaluate(expected, loss); } @@ -215,8 +224,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - BinaryCrossentropy instance = new BinaryCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + BinaryCrossentropy instance = new BinaryCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = (100.0f + 50.0f * labelSmoothing) / 3.0f; testSession.evaluate(expected, loss); } catch (Exception expected) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 13b287de3cd..3f6453b756a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -48,8 +48,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0F; testSession.evaluate(expected, loss); @@ -62,8 +63,9 @@ public void testAllCorrectUnweighted() { yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new 
CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); testSession.setEpsilon(1e-3F); testSession.evaluate(0.0F, loss); } @@ -81,7 +83,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] trueArray = { 1L, 0L, 0L, 0L, 1L, 0L, @@ -97,7 +100,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -109,7 +112,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; float[] predArray = { .9F, .05F, .05F, @@ -118,7 +122,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814F; testSession.evaluate(expected, loss); @@ -130,8 +134,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.0573755F; testSession.evaluate(expected, loss); } @@ -158,8 
+163,9 @@ public void testScalarWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3F); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267F; testSession.evaluate(expected, loss); @@ -171,8 +177,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386F; testSession.evaluate(expected, loss); } @@ -183,7 +190,8 @@ public void testSsampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; int[] trueArray = { 1, 0, 0, @@ -199,7 +207,7 @@ public void testSsampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampeWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696F; testSession.evaluate(expected, loss); @@ -211,8 +219,9 @@ public void testSsampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + 
instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829F; testSession.evaluate(expected, loss); } @@ -234,9 +243,9 @@ public void testNoReduction() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = - new CategoricalCrossentropy(tf, true, 0.0F, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, 0.0F, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822F, 0.000459F, 0.169846F}; testSession.evaluate(expected, loss); } @@ -254,8 +263,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = 400.0F * labelSmoothing / 3.0F; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java index b0d0442b3c7..d00f5374d61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -31,12 +31,13 @@ public void testReductionNone() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); + 
CategoricalHinge instance = new CategoricalHinge(Reduction.NONE); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {0.0f, 65.0f}; testSession.evaluate(expected, loss); } @@ -48,12 +49,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 32.5f; testSession.evaluate(expected, loss); } @@ -65,17 +67,18 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83.95f; testSession.evaluate(expected, loss); - Operand loss2 = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss2 = instance.call(tf, yTrue, yPred, sampleWeight); testSession.evaluate(loss, loss2); } } @@ -85,7 +88,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {1.2f, 3.4f}; @@ -93,7 +97,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 124.1f; testSession.evaluate(expected, loss); } @@ -104,13 +108,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -121,7 +126,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new 
CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {3, 6, 5, 0, 4, 2}; @@ -130,7 +136,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.0f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java index 8350d1403ed..2f21929a969 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -33,11 +33,12 @@ public void testReductionNone() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf, Reduction.NONE); + CosineSimilarity instance = new CosineSimilarity(Reduction.NONE); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {-0.720488f, 0.3460499f}; testSession.evaluate(expected, loss); } @@ -52,11 +53,12 @@ public void testUnweighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + 
Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } @@ -71,12 +73,13 @@ public void testScalarWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, 2.3f)); testSession.evaluate(expected, loss); } @@ -90,14 +93,15 @@ public void testSampleWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] weightsArray = {1.2f, 3.4f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, weightsArray)); testSession.evaluate(expected, loss); } @@ -108,14 +112,15 @@ public void testZeroWeighted() { for (TestSession.Mode 
tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -128,14 +133,15 @@ public void testTimestepWeighted() { Ops tf = testSession.getTF(); float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3, 1); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); float[] weightsArray = {3, 6, 5, 0, 4, 2}; Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -2.0f; testSession.evaluate(expected, loss); } @@ -149,11 +155,12 @@ public void testAxis() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf, 1); + CosineSimilarity instance = new CosineSimilarity(1); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), 
tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 4770511207e..d5fe846c82e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -33,12 +33,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.50625f; testSession.evaluate(expected, loss); } @@ -56,14 +57,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {2f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -75,13 +77,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.164375f; testSession.evaluate(expected, loss); @@ -94,7 +97,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -102,7 +106,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.06125f; testSession.evaluate(expected, loss); } @@ -113,13 +117,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 
4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -130,7 +135,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf, Reduction.AUTO); + Hinge instance = new Hinge(Reduction.AUTO); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -140,7 +146,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0125f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java index d1751f223a1..86a71e5ecbb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -32,8 +32,9 @@ public void testAllCorrect() { float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yTrue); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } 
@@ -50,8 +51,9 @@ public void testUnweighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yPred); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.10416666666666669f; testSession.evaluate(expected, loss); } @@ -67,9 +69,10 @@ public void testScalarWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.23958333333333337f; testSession.evaluate(expected, loss); @@ -87,10 +90,11 @@ public void testSampleWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.22766666666666668f; testSession.evaluate(expected, loss); } @@ -105,9 +109,10 @@ public void testZeroWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance 
= new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -125,10 +130,11 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .4025f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java index d57b61b18dd..1d7ee87b920 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.5960738398643668f; testSession.evaluate(expected, loss); } @@ -47,13 
+48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.3709698316880434f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0075711736936492f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, 
.4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); + KLDivergence instance = new KLDivergence(Reduction.AUTO); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.2495994912084345f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java index c4347b3fccb..ce6782cee3b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = 
new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 4.829245330860459f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.107264260979056f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
12.001114667519486f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf, Reduction.AUTO); + LogCosh instance = new LogCosh(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.653484271934046f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java index 3498c6d53aa..cbcb2c37391 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 5.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), 
tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 12.65f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 81.4f / 6f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ 
-116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.AUTO); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.NONE); + 
MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {10.733333f, 14.566667f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.SUM); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {25.29999f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java index 7816a8a288a..b521f2f5644 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -30,10 +30,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -45,12 +46,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 211.85184f; testSession.evaluate(expected, loss); } @@ -62,13 +64,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 487.25922f; testSession.evaluate(expected, loss); } @@ -79,7 +82,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -87,7 +91,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 422.8889f; testSession.evaluate(expected, loss); } @@ -98,13 +102,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -115,7 +120,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : 
tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.AUTO); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -125,7 +131,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 694.4445f; testSession.evaluate(expected, loss); } @@ -136,13 +142,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.NONE); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {621.8518f, 352.66666f}; testSession.evaluate(expected, loss); } @@ -153,13 +160,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new 
MeanAbsolutePercentageError(tf, Reduction.SUM); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 974.51843f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java index 1a971f0492b..e9fd0d7e349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 
9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 49.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 113.85f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
127.96667f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.AUTO); + MeanSquaredError instance = new MeanSquaredError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 97.833336f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new 
MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 173.25f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.NONE); + MeanSquaredError instance = new MeanSquaredError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {84.333336f, 143.36665f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.SUM); + MeanSquaredError instance = new MeanSquaredError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 
3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {227.69998f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java index 558f9c84659..0c6d411c53f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = 
instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 1.4370421f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.3051968f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.7856376f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.647374f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + 
float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {2.3006392f, 4.3097544f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand 
yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {6.6103935f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java index 55c59ca5ac6..c354c83bfe2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -3.306581945521002f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -7.605138474698304f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -6.147338926788071f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf, 
Reduction.AUTO); + Poisson instance = new Poisson(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -12.263126013890561f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index a6a0ff35c78..113b89b82ff 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -44,8 +44,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -57,8 +58,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, 
logits); testSession.evaluate(0.0f, loss); } } @@ -75,7 +77,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { 1.9f, .05f, .05f, @@ -86,7 +89,7 @@ public void testInvalidPredictionsRange() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -98,7 +101,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { .9f, .05f, .05f, @@ -107,7 +111,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814f; testSession.evaluate(expected, loss); @@ -119,8 +123,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.05737559f; testSession.evaluate(expected, loss); } @@ -143,8 +148,9 @@ public void testScalarWeighted() { Operand yPred = 
tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3f); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267f; testSession.evaluate(expected, loss); @@ -156,8 +162,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386f; testSession.evaluate(expected, loss); } @@ -168,7 +175,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + float[] sampleWeightArray = {1.2f, 3.4f, 5.6f}; int[] trueArray = {0, 1, 2}; float[] predArray = { @@ -180,7 +188,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696f; testSession.evaluate(expected, loss); @@ -192,8 +200,9 @@ public void testSampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = 
instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829f; testSession.evaluate(expected, loss); } @@ -216,8 +225,9 @@ public void testNoReduction() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy(tf, true, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + new SparseCategoricalCrossentropy(true, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822f, 0.000459f, 0.169846f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 57a012bbe9d..979e778e4c3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -32,12 +32,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.364062f; testSession.evaluate(expected, loss); } @@ -55,14 +56,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - SquaredHinge instance 
= new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 2, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -74,13 +76,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.8373437f; testSession.evaluate(expected, loss); } @@ -91,7 +94,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -99,7 +103,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.7043125f; testSession.evaluate(expected, loss); } @@ -110,13 +114,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -127,7 +132,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); + SquaredHinge instance = new SquaredHinge(Reduction.AUTO); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = @@ -137,7 +143,7 @@ public void testTimestepWeighted() { float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.54250000f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d6786b71972..d957cfb2508 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,13 +1,17 @@ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -26,10 +30,8 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -136,14 +138,14 @@ public void testDeterminism() { Ops tf = Ops.create(g); Glorot initializer = - new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); + new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); // Inputs Placeholder input = tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20))); // Fully connected layer Variable fcWeights = - tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(20L, 200L), TFloat32.class)); fcWeightName = fcWeights.op().name(); Variable 
fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f))); fcBiasName = fcBiases.op().name(); @@ -151,13 +153,13 @@ public void testDeterminism() { // Output layer Variable outputWeights = - tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(200L, 2L), TFloat32.class)); outputWeightName = outputWeights.op().name(); Variable outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f))); outputBiasName = outputBiases.op().name(); Add output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases); - // Loss + // AbstractLoss Placeholder placeholder = tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); Mean loss = @@ -205,12 +207,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - TFloat32 lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + TFloat32 lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); initialLoss[i] = lossVal.getFloat(); lossVal.close(); @@ -222,12 +227,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); postTrainingLoss[i] = lossVal.getFloat(); lossVal.close(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 181ae367f07..a4b98c002cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java 
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -17,25 +17,25 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.2f, 0.3f); + L1L2 instance = new L1L2(0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2(tf, 0, 0); + instance = new L1L2(0, 0); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0.5f, 0); + instance = new L1L2(0.5f, 0); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0, 0.5f); + instance = new L1L2(0, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); - instance = new L1L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + instance = new L1L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); } } @@ -44,8 +44,8 @@ public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf); - Operand result = instance.call(tf.constant(555f)); + L1L2 instance = new L1L2(); + Operand result = instance.call(tf, tf.constant(555f)); session.evaluate(3085.8f, result); } } @@ -55,10 +55,10 @@ public void testCallL1L2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0); + L1L2 instance = new L1L2(0, 0); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); - Operand 
result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -68,10 +68,10 @@ public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -83,10 +83,10 @@ public void testCallL1L2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -98,10 +98,10 @@ public void testCallL2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0); + L1L2 instance = new L1L2(0.01f, 0); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1(w, 0.01f); session.evaluate(expected, result); } @@ -112,10 +112,10 @@ public void testCallL1_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0.02f); 
+ L1L2 instance = new L1L2(0, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 0e42a257816..f7d540fb8e1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.2f); + L1 instance = new L1(0.2f); assertEquals(0.2f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf, 0f); + instance = new L1(0f); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + instance = new L1(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(0.f, instance.getL2()); } } @@ -36,10 +36,10 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.0f); + L1 instance = new L1(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0f, result); } } @@ -49,11 +49,11 @@ public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession 
session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf); + L1 instance = new L1(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL1(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.evaluate(expected, result); } } @@ -63,10 +63,10 @@ public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.02f); + L1 instance = new L1(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1(w, 0.02f); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index aba036ee306..4579ccaf551 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.2f); + L2 instance = new L2(0.2f); assertEquals(0.2f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2(tf, 0f); + instance = new L2(0f); assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - L2 instance64 = new L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + L2 
instance64 = new L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); assertEquals(0.f, instance64.getL1()); } } @@ -36,10 +36,10 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.0f); + L2 instance = new L2(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -49,11 +49,11 @@ public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf); + L2 instance = new L2(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL2(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.setEpsilon(.01f); session.evaluate(expected, result); } @@ -64,10 +64,10 @@ public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.02f); + L2 instance = new L2(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index fe2624cec3d..6918f631e6a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -14,13 +14,13 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 regularizer = new L1L2(tf, 0.01f, 0f); + L1L2 regularizer = new L1L2(0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand regularizerResult = regularizer.call(weights); - RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + Operand regularizerResult = regularizer.call(tf, weights); + RegularizerLoss lossInstance = new RegularizerLoss(regularizer); - Operand loss = lossInstance.call(null, null, weights); + Operand loss = lossInstance.call(tf, null, null, weights); session.evaluate(regularizerResult, loss); } } From b6ae875512716e205b5b7ac9164965773f8dbd9c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 1 Jun 2021 11:18:31 -0400 Subject: [PATCH 2/5] Move Ops from CTOR to call method --- .../activations/AbstractActivation.java | 46 ++++++ .../framework/activations/Activation.java | 45 +----- .../tensorflow/framework/activations/ELU.java | 34 ++--- .../framework/activations/Exponential.java | 23 +-- .../framework/activations/HardSigmoid.java | 31 ++-- .../framework/activations/Linear.java | 18 +-- .../framework/activations/ReLU.java | 30 ++-- .../framework/activations/SELU.java | 21 +-- .../framework/activations/Sigmoid.java | 21 +-- .../framework/activations/Softmax.java | 22 +-- .../framework/activations/Softplus.java | 21 +-- .../framework/activations/Softsign.java | 21 +-- .../framework/activations/Swish.java | 11 +- .../framework/activations/Tanh.java | 14 +- 
.../constraints/AbstractConstraint.java | 89 ++++++++++++ .../framework/constraints/Constraint.java | 88 +----------- .../framework/constraints/MaxNorm.java | 30 ++-- .../framework/constraints/MinMaxNorm.java | 30 ++-- .../framework/constraints/NonNeg.java | 15 +- .../framework/constraints/UnitNorm.java | 31 ++-- .../initializers/BaseInitializer.java | 21 ++- .../framework/initializers/Constant.java | 31 ++-- .../framework/initializers/Glorot.java | 12 +- .../tensorflow/framework/initializers/He.java | 16 +-- .../framework/initializers/Identity.java | 30 ++-- .../framework/initializers/Initializer.java | 7 +- .../framework/initializers/LeCun.java | 15 +- .../framework/initializers/Ones.java | 20 +-- .../framework/initializers/Orthogonal.java | 21 +-- .../framework/initializers/RandomNormal.java | 26 ++-- .../framework/initializers/RandomUniform.java | 31 ++-- .../initializers/TruncatedNormal.java | 23 +-- .../initializers/VarianceScaling.java | 32 ++--- .../framework/initializers/Zeros.java | 17 +-- .../framework/losses/BinaryCrossentropy.java | 79 +++++----- .../losses/CategoricalCrossentropy.java | 135 ++++++++---------- .../framework/losses/CategoricalHinge.java | 40 +++--- .../framework/losses/CosineSimilarity.java | 115 +++++++-------- .../tensorflow/framework/losses/Hinge.java | 48 +++---- .../tensorflow/framework/losses/Huber.java | 61 ++++---- .../framework/losses/KLDivergence.java | 50 ++++--- .../tensorflow/framework/losses/LogCosh.java | 54 ++++--- .../org/tensorflow/framework/losses/Loss.java | 78 +--------- .../framework/losses/MeanAbsoluteError.java | 44 +++--- .../losses/MeanAbsolutePercentageError.java | 45 +++--- .../framework/losses/MeanSquaredError.java | 44 +++--- .../losses/MeanSquaredLogarithmicError.java | 44 +++--- .../tensorflow/framework/losses/Poisson.java | 54 ++++--- .../framework/losses/Reduction.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 73 +++++----- .../framework/losses/SquaredHinge.java | 53 ++++--- 
.../framework/losses/impl/AbstractLoss.java | 89 ++++++++++++ .../org/tensorflow/framework/metrics/AUC.java | 95 ++++++------ .../framework/metrics/Accuracy.java | 8 +- .../framework/metrics/BinaryAccuracy.java | 8 +- .../metrics/CategoricalAccuracy.java | 19 ++- .../metrics/CategoricalCrossentropy.java | 20 ++- .../framework/metrics/FalseNegatives.java | 42 +++--- .../framework/metrics/FalsePositives.java | 42 +++--- .../tensorflow/framework/metrics/MeanIoU.java | 14 +- .../framework/metrics/MeanRelativeError.java | 11 +- .../framework/metrics/MeanTensor.java | 4 +- .../framework/metrics/Precision.java | 71 +++++---- .../framework/metrics/PrecisionAtRecall.java | 7 +- .../tensorflow/framework/metrics/Recall.java | 26 ++-- .../framework/metrics/RecallAtPrecision.java | 4 +- .../metrics/RootMeanSquaredError.java | 3 +- .../metrics/SensitivityAtSpecificity.java | 20 +-- .../metrics/SparseCategoricalAccuracy.java | 6 +- .../metrics/SpecificityAtSensitivity.java | 20 +-- .../org/tensorflow/framework/metrics/Sum.java | 8 +- .../metrics/TopKCategoricalAccuracy.java | 4 +- .../framework/metrics/TrueNegatives.java | 42 +++--- .../framework/metrics/TruePositives.java | 42 +++--- .../impl/ConfusionMatrixConditionCount.java | 26 ++-- .../framework/metrics/impl/LossMetric.java | 2 +- .../metrics/impl/MeanMetricWrapper.java | 8 +- .../framework/metrics/impl/MetricsHelper.java | 116 +++++++-------- .../impl/SensitivitySpecificityBase.java | 6 +- .../framework/metrics/impl/SetsOps.java | 24 ++-- .../framework/metrics/impl/SymbolicShape.java | 45 +++++- .../metrics/impl/WeightsBroadcastOps.java | 34 ++--- .../regularizers/AbstractRegularizer.java | 63 ++++++++ .../tensorflow/framework/regularizers/L1.java | 33 +++-- .../framework/regularizers/L1L2.java | 38 ++--- .../tensorflow/framework/regularizers/L2.java | 33 +++-- .../framework/regularizers/Regularizer.java | 67 +-------- .../regularizers/RegularizerLoss.java | 31 ++-- .../framework/activations/ELUTest.java | 33 +---- 
.../activations/ExponentialTest.java | 28 +--- .../activations/HardSigmoidTest.java | 28 +--- .../framework/activations/LinearTest.java | 28 +--- .../framework/activations/ReLUTest.java | 58 ++++---- .../framework/activations/SELUTest.java | 28 +--- .../framework/activations/SigmoidTest.java | 27 +--- .../framework/activations/SoftmaxTest.java | 47 ++---- .../framework/activations/SoftplusTest.java | 24 +--- .../framework/activations/SoftsignTest.java | 24 +--- .../framework/activations/SwishTest.java | 28 +--- .../framework/activations/TanhTest.java | 24 +--- .../framework/constraints/MaxNormTest.java | 8 +- .../framework/constraints/MinMaxNormTest.java | 4 +- .../framework/constraints/NonNegTest.java | 8 +- .../framework/constraints/UnitNormTest.java | 8 +- .../framework/initializers/ConstantTest.java | 66 ++++----- .../framework/initializers/GlorotTest.java | 57 ++++---- .../framework/initializers/HeTest.java | 57 ++++---- .../framework/initializers/IdentityTest.java | 34 ++--- .../framework/initializers/LeCunTest.java | 50 +++---- .../framework/initializers/OnesTest.java | 72 +++++----- .../initializers/OrthogonalTest.java | 34 ++--- .../initializers/RandomNormalTest.java | 33 ++--- .../initializers/RandomUniformTest.java | 38 ++--- .../initializers/TruncatedNormalTest.java | 33 ++--- .../initializers/VarianceScalingTest.java | 73 +++------- .../framework/initializers/ZerosTest.java | 72 +++++----- .../losses/BinaryCrossentropyTest.java | 54 ++++--- .../losses/CategoricalCrossentropyTest.java | 56 +++++--- .../losses/CategoricalHingeTest.java | 32 +++-- .../losses/CosineSimilarityTest.java | 35 +++-- .../framework/losses/HingeTest.java | 30 ++-- .../framework/losses/HuberTest.java | 30 ++-- .../framework/losses/KLDivergenceTest.java | 25 ++-- .../framework/losses/LogCoshTest.java | 25 ++-- .../losses/MeanAbsoluteErrorTest.java | 45 +++--- .../MeanAbsolutePercentageErrorTest.java | 40 +++--- .../losses/MeanSquaredErrorTest.java | 45 +++--- 
.../MeanSquaredLogarithmicErrorTest.java | 45 +++--- .../framework/losses/PoissonTest.java | 25 ++-- .../SparseCategoricalCrossentropyTest.java | 50 ++++--- .../framework/losses/SquaredHingeTest.java | 30 ++-- .../optimizers/GradientDescentTest.java | 48 ++++--- .../framework/regularizers/L1L2Test.java | 38 ++--- .../framework/regularizers/L1Test.java | 22 +-- .../framework/regularizers/L2Test.java | 22 +-- .../regularizers/RegularizerLossTest.java | 8 +- 136 files changed, 2278 insertions(+), 2544 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java new file mode 100644 index 00000000000..335b8697273 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Abstract base class for Activations */ +public abstract class AbstractActivation implements Activation { + + /** The TensorFlow Ops */ + protected Ops tf; + + /** Creates the abstract class for an AbstractActivation */ + protected AbstractActivation() {} + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + protected Ops getTF() { + return this.tf; + } + + /** + * Sets the TensorFlow Ops + * + * @param tf the TensorFlow Ops + */ + protected void setTF(Ops tf) { + this.tf = tf; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java index e1482a51a8a..f73c6678ab3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,50 +19,19 @@ import org.tensorflow.types.family.TNumber; /** - * Abstract base class for Activations + * Interface for Activations * - *

    Note: The {@link #tf} attribute must be set prior to invoking the call method. See - * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. - * - * @param the data type of the activation + * @param the data type of the input and the result */ -public abstract class Activation { - - /** The TensorFlow Ops */ - protected Ops tf; - - /** - * Creates the abstract class for an Activation - * - * @param tf the TensorFlow Ops - */ - protected Activation(Ops tf) { - this.tf = tf; - } - - /** - * Sets the TensorFlow Ops - * - * @param tf the TensorFlow Ops - */ - protected void setTF(Ops tf) { - this.tf = tf; - } - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - protected Ops getTF() { - return this.tf; - } +@FunctionalInterface +public interface Activation { /** * Gets the calculation operation for the activation. * + * @param tf the TensorFlow Ops * @param input the input tensor * @return The operand for the activation */ - public abstract Operand call(Operand input); + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 2f2f16f2752..919a947a127 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Exponential linear unit. 
* @@ -44,53 +46,41 @@ * Operand<TFloat32> result = elu.call(input); * * - * @param the data type of the activation * @see Clevert et al, 2016, Fast and Accurate Deep * Network Learning by Exponential Linear Units (ELUs) */ -public class ELU extends Activation { +public class ELU extends AbstractActivation { private static final double ALPHA_DEFAULT = 1.0; /** A scalar, slope of negative section. */ private final double alpha; - /** - * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. - * - * @param tf the TensorFlow Ops - */ - public ELU(Ops tf) { - this(tf, ALPHA_DEFAULT); + /** Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. */ + public ELU() { + this(ALPHA_DEFAULT); } /** * Creates a new ELU * - * @param tf the TensorFlow Ops * @param alpha A scalar, slope of negative section. It controls the value to which an ELU * saturates for negative net inputs. */ - public ELU(Ops tf, double alpha) { - super(tf); + public ELU(double alpha) { + super(); this.alpha = alpha; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { Operand result = tf.nn.elu(input); if (alpha == 1.0) return result; else { Class inputType = input.type(); - Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); - Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); + Operand y = tf.math.mul(result, cast(tf, tf.constant(alpha), inputType)); + Operand cond = tf.math.greater(result, cast(tf, tf.constant(0), inputType)); return tf.select(cond, result, y); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java index d5fdff36c61..8398ada6362 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -30,28 +30,17 @@ * Operand<TFloat32> result = exp.call(input); * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f] * - * - * @param the data type of the activation */ -public class Exponential extends Activation { +public class Exponential extends AbstractActivation { - /** - * Creates an Exponential activation. - * - * @param tf the TensorFlow Ops - */ - public Exponential(Ops tf) { - super(tf); + /** Creates an Exponential activation. */ + public Exponential() { + super(); } - /** - * Calculates the Exponential activation. - * - * @param input the input tensor - * @return an Operand for the exponential activation: exp(x). 
- */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.exp(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index 0b7cf573b8e..fac4d14eca5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -18,6 +18,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Hard sigmoid activation. * @@ -40,34 +42,23 @@ * Operand<TFloat32> result = hardSigmoid.call(input); * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f] * - * - * @param the data type of the result */ -public class HardSigmoid extends Activation { +public class HardSigmoid extends AbstractActivation { - /** - * Creates Hard sigmoid activation. - * - * @param tf the TensorFlow Ops - */ - public HardSigmoid(Ops tf) { - super(tf); + /** Creates Hard sigmoid activation. */ + public HardSigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); - Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); + Operand point2 = cast(tf, tf.constant(0.2), inputType); + Operand point5 = cast(tf, tf.constant(0.5), inputType); Operand x = tf.math.add(tf.math.mul(input, point2), point5); return tf.clipByValue( - x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType)); + x, cast(tf, tf.constant(0), inputType), cast(tf, tf.constant(1), inputType)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java index d907397995d..d1a5eede616 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -19,9 +19,9 @@ import org.tensorflow.types.family.TNumber; /** - * Linear activation function (pass-through). + * Linear activation function (pass-through). * - *

    The linear activation returns its input. It is also known as the Identity activation function.

    + *

    The linear activation returns its input. It is also known as the Identity activation function. * *

    For example: * @@ -33,20 +33,16 @@ * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f] * */ -public class Linear extends Activation { +public class Linear extends AbstractActivation { - /** - * Creates a linear activation. - * - * @param tf the TensorFlow Ops - */ - public Linear(Ops tf) { - super(tf); + /** Creates a linear activation. */ + public Linear() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return input; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index aef6ebf2992..c966e5d9ddd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -20,6 +20,8 @@ import org.tensorflow.op.nn.LeakyRelu; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Rectified Linear Unit(ReLU) activation. * @@ -58,7 +60,7 @@ * * @param the data type of the result */ -public class ReLU extends Activation { +public class ReLU extends AbstractActivation { public static final float ALPHA_DEFAULT = 0.0f; public static final float MAX_VALUE_DEFAULT = Float.NaN; @@ -71,24 +73,21 @@ public class ReLU extends Activation { /** * Creates a new ReLU with alpha={@link #ALPHA_DEFAULT}, maxValue={@link #MAX_VALUE_DEFAULT}, * threshold={@link #THRESHOLD_DEFAULT}, - * - * @param tf the TensorFlow Ops */ - public ReLU(Ops tf) { - this(tf, ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); + public ReLU() { + this(ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); } /** * Creates a new ReLU * - * @param tf the TensorFlow Ops * @param alpha governs the slope for values lower than the threshold. * @param maxValue sets the saturation threshold (the largest value the function will return). 
* @param threshold the threshold value of the activation function below which values will be * damped or set to zero. */ - public ReLU(Ops tf, float alpha, float maxValue, float threshold) { - super(tf); + public ReLU(float alpha, float maxValue, float threshold) { + super(); this.alpha = alpha; this.maxValue = maxValue; this.threshold = threshold; @@ -96,7 +95,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); @@ -108,7 +107,7 @@ public Operand call(Operand input) { if (threshold != 0) { negativePart = tf.nn.relu( - tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), inputType))); + tf.math.add(tf.math.neg(input), cast(tf, tf.constant(threshold), inputType))); } else { negativePart = tf.nn.relu(tf.math.neg(input)); } @@ -117,8 +116,8 @@ public Operand call(Operand input) { Operand lInput; if (threshold != 0) { // computes input for input > threshold else 0 - Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType)); - lInput = tf.math.mul(input, tf.dtypes.cast(greater, inputType)); + Greater greater = tf.math.greater(input, cast(tf, tf.constant(threshold), inputType)); + lInput = tf.math.mul(input, cast(tf, greater, inputType)); } else if (maxValue == 6) { // if no threshold, then can use nn.relu6 native TF op for performance lInput = tf.nn.relu6(input); @@ -127,15 +126,14 @@ public Operand call(Operand input) { lInput = tf.nn.relu(input); } if (clipMax) { - Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); - Operand zero = tf.dtypes.cast(tf.constant(0), inputType); + Operand lmaxValue = cast(tf, tf.constant(maxValue), inputType); + Operand zero = cast(tf, tf.constant(0), inputType); lInput = tf.clipByValue(lInput, zero, lmaxValue); } if (alpha != 0.) 
{ lInput = - tf.math.sub( - lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), inputType), negativePart)); + tf.math.sub(lInput, tf.math.mul(cast(tf, tf.constant(alpha), inputType), negativePart)); } return lInput; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java index f24731049fb..a28052486e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -45,25 +45,16 @@ * @param the data type of the activation * @see Klambauer et al., 2017 */ -public class SELU extends Activation { +public class SELU extends AbstractActivation { - /** - * Creates a Scaled Exponential Linear Unit (SELU) activation. - * - * @param tf the TensorFlow Ops - */ - public SELU(Ops tf) { - super(tf); + /** Creates a Scaled Exponential Linear Unit (SELU) activation. */ + public SELU() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.selu(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java index 5d507b38483..02b2daae4d6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -41,25 +41,16 @@ * * @param the data type of the activation */ -public class Sigmoid extends Activation { +public class Sigmoid extends AbstractActivation { - /** - * Creates a Sigmoid activation. 
- * - * @param tf the TensorFlow Ops - */ - public Sigmoid(Ops tf) { - super(tf); + /** Creates a Sigmoid activation. */ + public Sigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.sigmoid(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index 154e1ecc84a..3aa67a179ad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -38,7 +38,7 @@ * * @param the data type of the activation */ -public class Softmax extends Activation { +public class Softmax extends AbstractActivation { private static final int AXIS_DEFAULT = -1; @@ -47,32 +47,24 @@ public class Softmax extends Activation { /** * Creates a softmax activation where the default axis is {@link #AXIS_DEFAULT} which indicates * the last dimension. - * - * @param tf the TensorFlow Ops */ - public Softmax(Ops tf) { - this(tf, AXIS_DEFAULT); + public Softmax() { + this(AXIS_DEFAULT); } /** * Creates a Softmax activation * - * @param tf the TensorFlow Ops * @param axis The dimension softmax would be performed on. */ - public Softmax(Ops tf, int axis) { - super(tf); + public Softmax(int axis) { + super(); this.axis = axis; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java index 65a183ea047..8533de7852c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -32,25 +32,16 @@ * // 1.3132616e+00f, 2.0000000e+01f] * */ -public class Softplus extends Activation { +public class Softplus extends AbstractActivation { - /** - * Creates a Softplus activation function. - * - * @param tf the TensorFlow Ops - */ - public Softplus(Ops tf) { - super(tf); + /** Creates a Softplus activation function. */ + public Softplus() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.softplus(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java index 1f691e71862..249fa6077cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -33,25 +33,16 @@ * * @param the data type of the activation */ -public class Softsign extends Activation { +public class Softsign extends AbstractActivation { - /** - * Creates a Softsign activation. - * - * @param tf the TensorFlow Ops - */ - public Softsign(Ops tf) { - super(tf); + /** Creates a Softsign activation. */ + public Softsign() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.softsign(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java index d9f73a422d5..5007dd34555 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -40,7 +40,7 @@ * @param the data type of the activation * @see Ramachandran et al., 2017 */ -public class Swish extends Activation { +public class Swish extends AbstractActivation { /** * Creates a Swish activation, swish(x) = x * sigmoid(x). 
@@ -48,17 +48,14 @@ public class Swish extends Activation { *

    Swish activation function which returns x*sigmoid(x). It is a smooth, * non-monotonic function that consistently matches or outperforms ReLU on deep networks, it is * unbounded above and bounded below. - * - * @param tf the TensorFlow Ops */ - public Swish(Ops tf) { - super(tf); + public Swish() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. return tf.math.mul(input, tf.math.sigmoid(input)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java index 4fe02eed048..37d4d811a0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -33,20 +33,16 @@ * * @param the data type of the activation */ -public class Tanh extends Activation { +public class Tanh extends AbstractActivation { - /** - * Creates a Hyperbolic tangent activation. - * - * @param tf the TensorFlow Ops - */ - public Tanh(Ops tf) { - super(tf); + /** Creates a Hyperbolic tangent activation. */ + public Tanh() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.tanh(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java new file mode 100644 index 00000000000..266d01620bd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Base class for Constraints. AbstractConstraint subclasses impose constraints on weight values */ +public abstract class AbstractConstraint implements Constraint { + + public static final float EPSILON = 1e-7f; + + /** Creates a AbstractConstraint */ + public AbstractConstraint() {} + + /** + * Gets the element-wise square root. + * + * @param tf the TensorFlow Ops + * @param x the input Operand. + * @return the element-wise square root. + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand sqrt(Ops tf, Operand x) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); + } + + /** + * Gets the element-wise value clipping. 
+ * + * @param tf the TensorFlow Ops + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand clip( + Ops tf, Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } + + /** + * Calculates the norm of the weights along the axes + * + * @param tf the TensorFlow Ops + * @param weights the weights used to calculate the norms + * @param axes the axes along which to calculate weight norms. + * @param the data type for the weights and the result + * @return the norms + * @throws IllegalArgumentException if weights is null + */ + protected Operand norm(Ops tf, Operand weights, int[] axes) { + if (weights == null) throw new IllegalArgumentException("weights must not be null"); + return sqrt( + tf, + tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index 306361959bf..97640b19cf8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,96 +16,16 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** Base class for Constraints. Constraint subclasses impose constraints on weight values */ -public abstract class Constraint { - - public static final float EPSILON = 1e-7f; - - private final Ops tf; - - /** - * Creates a Constraint - * - * @param tf the TensorFlow Ops - */ - public Constraint(Ops tf) { - this.tf = tf; - } - +public interface Constraint { /** * Applies the constraint against the provided weights * + * @param tf the TensorFlow Ops * @param weights the weights * @return the constrained weights * @param the data type for weights and results. */ - public abstract Operand call(Operand weights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the element-wise square root. - * - * @param x the input Operand. - * @return the element-wise square root. - * @param The data type for the operand and result. - * @throws IllegalArgumentException if x is null - */ - protected Operand sqrt(Operand x) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Class type = x.type(); - Operand zero = cast(tf, tf.constant(0), type); - Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); - return tf.math.sqrt(tf.clipByValue(x, zero, inf)); - } - - /** - * Gets the element-wise value clipping. - * - * @param x the Operand to clip - * @param minValue the minimum value - * @param maxValue the maximum value - * @return the operand with clipped values - * @param The data type for the operand and result. 
- * @throws IllegalArgumentException if x is null - */ - protected Operand clip(Operand x, double minValue, double maxValue) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Ops tf = getTF(); - Class type = x.type(); - - double min = Math.min(minValue, maxValue); - double max = Math.max(minValue, maxValue); - - Operand minValueConstant = cast(tf, tf.constant(min), type); - Operand maxValueConstant = cast(tf, tf.constant(max), type); - return tf.clipByValue(x, minValueConstant, maxValueConstant); - } - - /** - * Calculates the norm of the weights along the axes - * - * @param weights the weights used to calculate the norms - * @param axes the axes along which to calculate weight norms. - * @param the data type for the weights and the result - * @return the norms - * @throws IllegalArgumentException if weights is null - */ - protected Operand norm(Operand weights, int[] axes) { - if (weights == null) throw new IllegalArgumentException("weights must not be null"); - return sqrt( - tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); - } + Operand call(Ops tf, Operand weights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index 1dae117b113..b9f082f54de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -24,7 +24,7 @@ * Constrains the weights incident to each hidden unit to have a norm less than or equal to a * desired value. 
*/ -public class MaxNorm extends Constraint { +public class MaxNorm extends AbstractConstraint { public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; @@ -36,54 +36,48 @@ public class MaxNorm extends Constraint { /** * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link * #AXIS_DEFAULT} for the axis. - * - * @param tf the TensorFlow Ops */ - public MaxNorm(Ops tf) { - this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + public MaxNorm() { + this(MAX_VALUE_DEFAULT, AXIS_DEFAULT); } /** * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. */ - public MaxNorm(Ops tf, double maxValue) { - this(tf, maxValue, AXIS_DEFAULT); + public MaxNorm(double maxValue) { + this(maxValue, AXIS_DEFAULT); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axis axis along which to calculate weight norms. */ - public MaxNorm(Ops tf, double maxValue, int axis) { - this(tf, maxValue, new int[] {axis}); + public MaxNorm(double maxValue, int axis) { + this(maxValue, new int[] {axis}); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axes axes along which to calculate weight norms. 
*/ - public MaxNorm(Ops tf, double maxValue, int[] axes) { - super(tf); + public MaxNorm(double maxValue, int[] axes) { + super(); this.maxValue = maxValue; this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Operand norms = norm(weights, getAxes()); - Operand desired = clip(norms, 0f, this.getMaxValue()); + Operand norms = norm(tf, weights, getAxes()); + Operand desired = clip(tf, norms, 0f, this.getMaxValue()); return tf.math.mul( weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 04b21572e55..97e86d7693f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -21,7 +21,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have the norm between a lower bound and an upper bound. 
*/ -public class MinMaxNorm extends Constraint { +public class MinMaxNorm extends AbstractConstraint { public static final double MIN_VALUE_DEFAULT = 0.0; public static final double MAX_VALUE_DEFAULT = 1.0; public static final double RATE_DEFAULT = 1.0; @@ -47,48 +47,43 @@ public class MinMaxNorm extends Constraint { * Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis - * - * @param tf the TensorFlow Ops */ - public MinMaxNorm(Ops tf) { - this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm() { + this(MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue) { - this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm(double minValue, double maxValue) { + this(minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axis integer, axis along which to calculate weight norms. 
*/ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { - this(tf, minValue, maxValue, rate, new int[] {axis}); + public MinMaxNorm(double minValue, double maxValue, double rate, int axis) { + this(minValue, maxValue, rate, new int[] {axis}); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axes integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { - super(tf); + public MinMaxNorm(double minValue, double maxValue, double rate, int[] axes) { + super(); this.minValue = minValue; this.maxValue = maxValue; this.rate = rate; @@ -97,15 +92,14 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] a /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); - Operand norms = norm(weights, getAxes()); + Operand norms = norm(tf, weights, getAxes()); Operand desired = tf.math.add( tf.math.mul( tf.dtypes.cast(tf.constant(this.getRate()), type), - clip(norms, this.getMinValue(), this.getMaxValue())), + clip(tf, norms, this.getMinValue(), this.getMaxValue())), tf.math.mul( tf.math.sub( tf.dtypes.cast(tf.constant(1), type), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java index 0194b2fadb6..6a5677983fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -19,21 +19,16 @@ import org.tensorflow.types.family.TNumber; /** Constrains the weights to be 
non-negative. */ -public class NonNeg extends Constraint { +public class NonNeg extends AbstractConstraint { - /** - * Create a NonNeg constraint - * - * @param tf the TensorFlow Ops - */ - public NonNeg(Ops tf) { - super(tf); + /** Create a NonNeg constraint */ + public NonNeg() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); return tf.math.mul( weights, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java index 70bb1a59785..fdd71945229 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -21,50 +21,43 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have unit norm. */ -public class UnitNorm extends Constraint { +public class UnitNorm extends AbstractConstraint { public static final int AXIS_DEFAULT = 0; /** integer, axis along which to calculate weight norms. */ private final int[] axes; - /** - * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - public UnitNorm(Ops tf) { - this(tf, AXIS_DEFAULT); + /** Create a UnitNorm AbstractConstraint with the axis set to {@link #AXIS_DEFAULT} */ + public UnitNorm() { + this(AXIS_DEFAULT); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axis axis along which to calculate weight norms. 
*/ - public UnitNorm(Ops tf, int axis) { - this(tf, new int[] {axis}); + public UnitNorm(int axis) { + this(new int[] {axis}); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axes axes along which to calculate weight norms. */ - public UnitNorm(Ops tf, int[] axes) { - super(tf); + public UnitNorm(int[] axes) { + super(); this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); return tf.math.div( - weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(weights, getAxes()))); + weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(tf, weights, getAxes()))); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java index 9c1fa9ac287..56e3d310280 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java @@ -14,29 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; /** Abstract base class for all Initializers */ public abstract class BaseInitializer implements Initializer { - protected final Ops tf; + private final String name; - /** - * Creates an Initializer - * - * @param tf the TensorFlow Ops - */ - protected BaseInitializer(Ops tf) { - this.tf = tf; + /** Creates an Initializer */ + protected BaseInitializer() { + name = getClass().getSimpleName(); } /** - * Gets the TensorFlow Ops + * Gets the name for this initializer * - * @return the TensorFlow Ops + * @return the name for this initializer */ 
- public Ops getTF() { - return tf; + public String getName() { + return name; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 4a2df86d74b..508fb69fd55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a constant value. * @@ -30,7 +32,7 @@ * Constant<TFloat32> initializer = * new org.tensorflow.framework.initializers.Constant<>(tf, 3f); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The Type for the call operation @@ -45,11 +47,10 @@ public class Constant extends BaseInitializer { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a long value used for the constant. */ - public Constant(Ops tf, long value) { - super(tf); + public Constant(long value) { + super(); longValue = value; doubleValue = 0; booleanValue = false; @@ -59,11 +60,10 @@ public Constant(Ops tf, long value) { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a double value used for the constant. */ - public Constant(Ops tf, double value) { - super(tf); + public Constant(double value) { + super(); doubleValue = value; longValue = 0; booleanValue = false; @@ -73,11 +73,10 @@ public Constant(Ops tf, double value) { /** * Creates an Initializer that generates tensors with a constant value. 
* - * @param tf the TensorFlow Ops * @param value a boolean value used for the constant. */ - public Constant(Ops tf, boolean value) { - super(tf); + public Constant(boolean value) { + super(); booleanValue = value; doubleValue = 0; longValue = 0; @@ -86,17 +85,19 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); + return tf.fill(dims, cast(tf, tf.constant(longValue), type)); case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); + return tf.fill(dims, cast(tf, tf.constant(doubleValue), type)); default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); + return tf.fill(dims, cast(tf, tf.constant(booleanValue), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 894bd073758..4a39c3839f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -43,7 +42,7 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, 
tf.constant(Shape.of(2,2)), TFloat32.class); * * *

    Glorot Uniform: @@ -54,12 +53,14 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

    NOTE: + * *

    For a GlorotNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. + * *

    For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. * @@ -74,13 +75,12 @@ public class Glorot extends VarianceScaling { /** * Creates a Glorot initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public Glorot(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_AVG, distribution, seed); + public Glorot(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_AVG, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 3a91b72b0d0..4a9fa8a7849 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -38,7 +37,7 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.TRUNCATED_NORMAL, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

    He Uniform: @@ -49,14 +48,16 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.UNIFORM, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

    NOTE: + * *

    For an HeNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. - *

    For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} - * for the distribution parameter. + * + *

    For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for + * the distribution parameter. * * @param The TType for the call operation * @see extends VarianceScaling { /** * Creates an He Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the He initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public He(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_IN, distribution, seed); + public He(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index f672c9f1e85..34a77520406 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -21,6 +21,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates the identity matrix. * @@ -32,40 +34,34 @@ * Identity<TFloat32> initializer = * new org.tensorflow.framework.initializers.Identity<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; - private final double gain; - /** - * Creates an Initializer that generates the identity matrix. 
- * - * @param tf the TensorFlow Ops - */ - public Identity(Ops tf) { - super(tf); - this.gain = GAIN_DEFAULT; + /** Creates an Initializer that generates the identity matrix. */ + public Identity() { + this(GAIN_DEFAULT); } /** * Creates an Initializer that generates the identity matrix. * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Identity Matrix */ - public Identity(Ops tf, double gain) { - super(tf); + public Identity(double gain) { + super(); this.gain = gain; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -75,9 +71,9 @@ public Operand call(Operand dims, Class type) { Shape diagShape = Shape.of(diagSize); Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), type); + Operand zero = cast(tf, tf.constant(0), type); Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); + tf.fill(tf.constant(diagShape.asArray()), cast(tf, tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -91,6 +87,6 @@ public Operand call(Operand dims, Class type) { op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); + return tf.math.mul(op, cast(tf, tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 4beb218783b..d6593b770e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -15,6 +15,7 @@ package 
org.tensorflow.framework.initializers; import org.tensorflow.Operand; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -23,14 +24,18 @@ * * @param The data Type for initializer operation */ +@FunctionalInterface public interface Initializer { /** * Generates the operation used to perform the initialization. * + * @param tf the TensorFlow Ops * @param dims the shape dimensions * @param type the type of tensor + * @throws IllegalStateException if the object has not been initialized with the TensorFlow + * Platform. * @return An operand for the initialization. */ - Operand call(Operand dims, Class type); + Operand call(Ops tf, Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index 38e68ef688b..364c5fb9285 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -27,7 +26,7 @@ * stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the * weight tensor. * - *

    If the distribution is UNIFORM, itraws samples from a uniform distribution within + *

    If the distribution is UNIFORM, it draws samples from a uniform distribution within * [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is * the number of input units in the weight tensor) * @@ -41,7 +40,7 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

    LeCun Uniform: @@ -52,14 +51,15 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * * * *

    NOTE: * * - *

    For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * + *

    For a LeCunNormal equivalent initializer, use {@link + * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * * *

    For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * * for the distribution parameter. * @@ -79,12 +79,11 @@ public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public LeCun(Ops tf, Distribution distribution, long seed) { - super(tf, 1.0, Mode.FAN_IN, distribution, seed); + public LeCun(Distribution distribution, long seed) { + super(1.0, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b8eb0c418e9..6e818d30bd7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors initialized to 1. 
* @@ -30,7 +32,7 @@ * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,21 +48,21 @@ public class Ones extends BaseInitializer { * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param tf the TensorFlow Ops */ - public Ones(Ops tf) { - super(tf); + public Ones() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); + return tf.fill(dims, cast(tf, tf.constant(1), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index a5b466e118e..519d0cd042e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -23,6 +23,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates an orthogonal matrix. 
* @@ -42,7 +44,7 @@ * Orthogonal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.Orthogonal<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -57,31 +59,30 @@ public class Orthogonal extends BaseInitializer { /** * Creates an Orthogonal Initializer using {@link #GAIN_DEFAULT} for the gain. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public Orthogonal(Ops tf, long seed) { - this(tf, GAIN_DEFAULT, seed); + public Orthogonal(long seed) { + this(GAIN_DEFAULT, seed); } /** * Creates an Orthogonal Initializer * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Matrix. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public Orthogonal(Ops tf, double gain, long seed) { - super(tf); + public Orthogonal(double gain, long seed) { + super(); this.gain = gain; this.seed = seed; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -101,10 +102,10 @@ public Operand call(Operand dims, Class type) { Output qo = qrOp.q(); Output ro = qrOp.r(); Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); + tf.linalg.matrixDiagPart(ro, tf.constant(0), cast(tf, tf.constant(0), type)); Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); + return tf.math.mul(qop, cast(tf, tf.constant(this.gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 38ab194a56b..9a52a641416 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a normal distribution. 
* @@ -29,7 +31,7 @@ * RandomNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,37 +49,34 @@ public class RandomNormal extends BaseInitializer { * Creates the RandomUniform initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public RandomNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates the RandomUniform initializer using {@link #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, double mean, long seed) { - this(tf, mean, STDDEV_DEFAULT, seed); + public RandomNormal(double mean, long seed) { + this(mean, STDDEV_DEFAULT, seed); } /** * creates the RandomUniform initializer * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public RandomNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -85,10 +84,11 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); + Operand op = tf.math.mul(distOp, cast(tf, tf.constant(this.stddev), type)); + return tf.math.add(op, cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 787af15f709..7288024f5b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a uniform distribution. 
* @@ -31,7 +33,7 @@ * RandomUniform<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomUniform<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,28 +48,26 @@ public class RandomUniform extends BaseInitializer { private final long seed; /** - * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and - * {@link #MAXVAL_DEFAULT} for the maxval + * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and {@link + * #MAXVAL_DEFAULT} for the maxval * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomUniform(Ops tf, long seed) { - this(tf, MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); + public RandomUniform(long seed) { + this(MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); } /** * Creates a RandomUniform initializer * - * @param tf the TensorFlow Ops * @param minval Lower bound of the range of random values to generate (inclusive). * @param maxval Upper bound of the range of random values to generate (exclusive). * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomUniform(Ops tf, double minval, double maxval, long seed) { - super(tf); + public RandomUniform(double minval, double maxval, long seed) { + super(); this.minval = minval; this.maxval = maxval; this.seed = seed; @@ -75,26 +75,27 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Operand distOp; if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), type), - tf.dtypes.cast(tf.constant(this.maxval), type), + cast(tf, tf.constant(this.minval), type), + cast(tf, tf.constant(this.maxval), type), options); } else { long[] seeds = {seed, 0}; distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval), type)); } } else { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); - distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval - this.minval), type)); + distOp = tf.math.add(distOp, cast(tf, tf.constant(this.minval), type)); } } return distOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index d3cfec26338..8069d5d9c7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -19,6 +19,8 @@ import 
org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates a truncated normal distribution. * @@ -29,7 +31,7 @@ * TruncatedNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.TruncatedNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,25 +49,23 @@ public class TruncatedNormal extends BaseInitializer { * Creates a TruncatedNormal Initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public TruncatedNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public TruncatedNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates a TruncatedNormal Initializer. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public TruncatedNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -73,11 +73,12 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - long[] seeds = {seed,0}; + public Operand call(Ops tf, Operand dims, Class type) { + + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), - tf.dtypes.cast(tf.constant(mean), type)); + tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)), + cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index 5d951450505..a04e4a9a378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -21,11 +21,13 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

    With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from - * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after + *

    With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a + * truncated/untruncated normal distribution with a mean of zero and a standard deviation (after * truncation, if used) stddev = Math.sqrt(scale / n), where n is: * *

      @@ -46,7 +48,7 @@ * new org.tensorflow.framework.initializers.VarianceScaling<>( * tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -64,28 +66,25 @@ public class VarianceScaling extends BaseInitializer { private final Distribution distribution; private final long seed; - /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param seed sed to create random seeds. */ - public VarianceScaling(Ops tf, long seed) { - this(tf, SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); + public VarianceScaling(long seed) { + this(SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); } /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param scale Scaling factor (positive float). * @param mode the mode for the variance * @param distribution Random distribution to use. * @param seed Used to create random seeds. 
*/ - public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distribution, long seed) { - super(tf); + public VarianceScaling(double scale, Mode mode, Distribution distribution, long seed) { + super(); if (scale <= 0.0) { throw new IllegalArgumentException("scale must be greater than 0, got " + scale); } @@ -97,8 +96,9 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); + public Operand call(Ops tf, Operand dims, Class type) { + + Shape shape = ShapeUtils.toShape(tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); switch (mode) { @@ -119,18 +119,18 @@ public Operand call(Operand dims, Class type) { switch (distribution) { case TRUNCATED_NORMAL: distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); - stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + stddev = Math.sqrt(lscale) / 0.87962566103423978; + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case NORMAL: distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case UNIFORM: distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; } return mulOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 4298493ac44..f581d247deb 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -28,24 +28,21 @@ * Zeros<TFloat32> initializer = * new org.tensorflow.framework.initializers.Zeros<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Zeros extends BaseInitializer { - /** - * Creates an Initializer that sets all values to one. - * - * @param tf the TensorFlow Ops - */ - public Zeros(Ops tf) { - super(tf); + /** Creates an Initializer that sets all values to zero. */ + public Zeros() { + super(); } @Override - public Operand call(Operand dims, Class dtype) { - return tf.zeros(dims, dtype); + public Operand call(Ops tf, Operand dims, Class type) { + + return tf.zeros(dims, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 3417c07372a..0c7c6abf8af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * BinaryCrossentropy bce = new BinaryCrossentropy(tf); - * Operand<TFloat32> result = bce.call(labels, predictions); + * Operand<TFloat32> result = bce.call(tf, labels, predictions); * // produces 0.815 
* * @@ -43,7 +44,7 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
      - *    Operand<TFloat32> result = bce.call(labels, predictions, sampleWeight);
+ *    Operand<TFloat32> result = bce.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.458f
        * 
      * @@ -51,7 +52,7 @@ * *
        *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = bce.call(labels, predictions);
+ *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
        *    // produces 1.630f
        * 
      * @@ -59,11 +60,11 @@ * *
        *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = bce.call(labels, predictions);
+ *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
        *    // produces [0.916f, 0.714f]
        * 
      */ -public class BinaryCrossentropy extends Loss { +public class BinaryCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; @@ -71,70 +72,63 @@ public class BinaryCrossentropy extends Loss { private final float labelSmoothing; /** - * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Binary Crossentropy AbstractLoss using {@link Class#getSimpleName()} as the loss + * name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public BinaryCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link * #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public BinaryCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); + public BinaryCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link - * Loss#REDUCTION_DEFAULT}, + * AbstractLoss#REDUCTION_DEFAULT}, * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a - * reduction of {@link Loss#REDUCTION_DEFAULT}. + * reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, - * and a reduction of {@link Loss#REDUCTION_DEFAULT}. + * and a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. 
When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** - * Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}. + * Creates a Binary Crossentropy loss using a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -142,14 +136,13 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, @@ -157,14 +150,13 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSm * correspond to heavier smoothing. 
* @param reduction Type of Reduction to apply to the loss. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -175,8 +167,8 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Redu * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public BinaryCrossentropy( - Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. 
and <= 1, found " + labelSmoothing); @@ -207,24 +199,25 @@ public BinaryCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.binaryCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 5aac163c1e4..7d65353b004 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * CategoricalCrossentropy cce = new CategoricalCrossentropy(tf); - * 
Operand<TFloat32> result = cce.call(labels, predictions); + * Operand<TFloat32> result = cce.call(tf, labels, predictions); * // produces 1.177 * * @@ -45,15 +46,15 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
      - *    Operand<TFloat32> result = cce.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = cce.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.814f
        * 
      * *

      Using SUM reduction type: * *

      - *    CategoricalCrossentropy cce = new CategoricalCrossentropy(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = cce.call(labels, predictions);
      + *    CategoricalCrossentropy cce = new CategoricalCrossentropy(Reduction.SUM);
      + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
        *    // produces 2.354f
        * 
      * @@ -61,12 +62,12 @@ * *
        *    CategoricalCrossentropy cce =
      - *        new CategoricalCrossentropy(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = cce.call(labels, predictions);
      + *        new CategoricalCrossentropy(Reduction.NONE);
      + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
        *    // produces [0.0513f, 2.303f]
        * 
      */ -public class CategoricalCrossentropy extends Loss { +public class CategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST; @@ -76,98 +77,90 @@ public class CategoricalCrossentropy extends Loss { private final int axis; /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and an axis of {@link - * #DEFAULT_AXIS} - * - * @param tf the TensorFlow Ops + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis + * of {@link #DEFAULT_AXIS} */ - public CategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #FROM_LOGITS_DEFAULT} for + * fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss */ - public CategoricalCrossentropy(Ops tf, String name) { - this(tf, name, 
FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for * labelSmoothing and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link - * #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss {@link #FROM_LOGITS_DEFAULT} for fromLogits, + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param reduction Type of Reduction to apply to loss. 
*/ - public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of - * {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a + * channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, 
REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a channel + * axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -175,15 +168,14 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) *
      means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means @@ -191,15 +183,13 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * for label 1 * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy( - Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss + * Creates a categorical cross entropy AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -213,13 +203,8 @@ public CategoricalCrossentropy( * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. 
*/ public CategoricalCrossentropy( - Ops tf, - String name, - boolean fromLogits, - float labelSmoothing, - Reduction reduction, - int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction, int axis) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); @@ -251,24 +236,24 @@ public CategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.categoricalCrossentropy( - getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.categoricalCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 73837ed1756..c9987fb0884 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import 
org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * CategoricalHinge categoricalHinge = new CategoricalHinge(tf); - * Operand<TFloat32> result = categoricalHinge.call(labels, predictions); + * Operand<TFloat32> result = categoricalHinge.call(Ops tf, labels, predictions); * // produces 1.4 * * @@ -43,7 +44,7 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1f, 0.f});
      - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.6f
        * 
      * @@ -51,7 +52,7 @@ * *
        *    CategoricalHinge categoricalHinge = new CategoricalHinge(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
      + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
        *    // produces 2.8f
        * 
      * @@ -60,48 +61,45 @@ *
        *    CategoricalHinge categoricalHinge =
        *        new CategoricalHinge(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
      + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
        *    // produces [1.2f, 1.6f]
        * 
      */ -public class CategoricalHinge extends Loss { +public class CategoricalHinge extends AbstractLoss { /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public CategoricalHinge(Ops tf) { - super(tf); + public CategoricalHinge() { + super(); } /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CategoricalHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public CategoricalHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Categorical Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public CategoricalHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public CategoricalHinge(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.categoricalHinge(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 0a18d93caf3..ac810139d71 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -40,7 +41,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 0.f}, {1.f, 1.f}}); * CosineSimilarity cosineLoss = new CosineSimilarity(tf); - * Operand<TFloat32> result = cosineLoss.call(labels, predictions); + * Operand<TFloat32> result = cosineLoss.call(Ops tf, labels, predictions); * // produces -0.5 * * @@ -48,7 +49,7 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
      - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions, sampleWeight);
        *    // produces -0.0999f
        * 
      * @@ -56,7 +57,7 @@ * *
        *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
      + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
        *    // produces -0.999f
        * 
      * @@ -64,165 +65,155 @@ * *
        *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
      + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
        *    // produces [-0.f, -0.999f]
        * 
      */ -public class CosineSimilarity extends Loss { +public class CosineSimilarity extends AbstractLoss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; private final int[] axis; /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis - * of {@link #DEFAULT_AXIS}, and a Loss Reduction of {@link #DEFAULT_REDUCTION} - * - * @param tf the TensorFlow Ops + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * an axis of {@link #DEFAULT_AXIS}, and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} */ - public CosineSimilarity(Ops tf) { + public CosineSimilarity() { - this(tf, null, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(null, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS}, and a Loss Reduction - * of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS}, and a + * AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss */ - public CosineSimilarity(Ops tf, String name) { + public CosineSimilarity(String name) { - this(tf, name, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(name, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, int axis) { + public CosineSimilarity(int axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, int[] axis) { + public CosineSimilarity(int[] axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, String name, int axis) { + public CosineSimilarity(String name, int axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis) { + public CosineSimilarity(String name, int[] axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an - * axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, Reduction reduction) { + public CosineSimilarity(Reduction reduction) { - this(tf, null, DEFAULT_AXIS, reduction); + this(null, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, Reduction reduction) { + public CosineSimilarity(String name, Reduction reduction) { - this(tf, name, DEFAULT_AXIS, reduction); + this(name, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + public CosineSimilarity(int axis, Reduction reduction) { - this(tf, null, new int[] {axis}, reduction); + this(null, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + public CosineSimilarity(int[] axis, Reduction reduction) { - this(tf, null, axis, reduction); + this(null, axis, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { - this(tf, name, new int[] {axis}, reduction); + public CosineSimilarity(String name, int axis, Reduction reduction) { + this(name, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { - super(tf, name, reduction); + public CosineSimilarity(String name, int[] axis, Reduction reduction) { + super(name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.cosineSimilarity(tf, labels, predictions, axis); losses = tf.math.neg(losses); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index d4c350ef06c..05c5b47e329 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Hinge hingeLoss = new Hinge(tf); - * Operand<TFloat32> result = hingeLoss.call(labels, predictions); + * Operand<TFloat32> result = hingeLoss.call(Ops tf, labels, predictions); * // produces 1.3f * * @@ -45,57 +46,53 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
      - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.55f
        * 
      * *

      Using SUM reduction type: * *

      - *    Hinge hingeLoss = new Hinge(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
      + *    Hinge hingeLoss = new Hinge(Reduction.SUM);
      + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
        *    // produces 2.6f
        * 
      * *

      Using NONE reduction type: * *

      - *    Hinge hingeLoss = new Hinge(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
      + *    Hinge hingeLoss = new Hinge(Reduction.NONE);
      + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
        *    // produces [1.1f, 1.5f]
        * 
      */ -public class Hinge extends Loss { +public class Hinge extends AbstractLoss { /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction - * of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Hinge(Ops tf) { - this(tf, null, Reduction.AUTO); + public Hinge() { + this(null, Reduction.AUTO); } /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Hinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public Hinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public Hinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Hinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -122,15 +119,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.hinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.hinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index b1aee1b0656..c9a7d7edcb8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -39,7 +40,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Huber huberLoss = new Huber(tf); - * Operand<TFloat32> result = huberLoss.call(labels, predictions); + * Operand<TFloat32> result = huberLoss.call(Ops tf, labels, predictions); * // produces 0.155 * * 
@@ -47,7 +48,7 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
      - *    Operand<TFloat32> result = huberLoss.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.09f
        * 
      * @@ -55,7 +56,7 @@ * *
        *    Huber huberLoss = new Huber(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
      + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
        *    // produces 0.32f
        * 
      * @@ -63,78 +64,74 @@ * *
        *    Huber huberLoss = new Huber(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
      + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
        *    // produces [0.18f, 0.13f]
        * 
      * * @see
      Huber loss */ -public class Huber extends Loss { +public class Huber extends AbstractLoss { public static final float DELTA_DEFAULT = 1.0f; private final float delta; /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name, {@link + * #DELTA_DEFAULT} as the delta and a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} */ - public Huber(Ops tf) { - this(tf, null, DELTA_DEFAULT, Reduction.AUTO); + public Huber() { + this(null, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Huber(Ops tf, String name) { - this(tf, name, DELTA_DEFAULT, Reduction.AUTO); + public Huber(String name) { + this(name, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name and and {@link - * #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name and and + * {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, Reduction reduction) { - this(tf, null, DELTA_DEFAULT, reduction); + public Huber(Reduction reduction) { + this(null, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ - public Huber(Ops tf, String name, Reduction reduction) { - this(tf, name, DELTA_DEFAULT, reduction); + public Huber(String name, Reduction reduction) { + this(name, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss + * Creates a Huber AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, String name, float delta, Reduction reduction) { - super(tf, name, reduction); + public Huber(String name, float delta, Reduction reduction) { + super(name, reduction); this.delta = delta; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.huber(getTF(), labels, predictions, delta); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.huber(tf, labels, predictions, delta); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 2aa1f72092b..ef5d88539db 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -31,8 +32,8 @@ * tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}}); * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); - * KLDivergence kld = new KLDivergence(tf); - * Operand<TFloat32> result = kld.call(labels, predictions); + * KLDivergence kld = new KLDivergence(); + * Operand<TFloat32> result = kld.call(Ops tf, labels, predictions); * // produces 0.458 * * @@ -40,68 +41,65 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
      - *    Operand<TFloat32> result = kld.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = kld.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.366f
        * 
      * *

      Using SUM reduction type: * *

      - *    KLDivergence kld = new KLDivergence(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = kld.call(labels, predictions);
      + *    KLDivergence kld = new KLDivergence(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = kld.call(tf, labels, predictions);
        *    // produces 0.916f
        * 
      * *

      Using NONE reduction type: * *

      - *    KLDivergence kld = new KLDivergence(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = kld.call(labels, predictions);
      + *    KLDivergence kld = new KLDivergence(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = kld.call(tf, labels, predictions);
        *    // produces [0.916f, -3.08e-06f]
        * 
      * * @see Kullback?Leibler * divergence */ -public class KLDivergence extends Loss { +public class KLDivergence extends AbstractLoss { /** - * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Kullback Leibler Divergence AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public KLDivergence(Ops tf) { - super(tf); + public KLDivergence() { + super(); } /** - * Creates a Kullback Leibler Divergence Loss Loss using {@link Class#getSimpleName()} as the loss - * name + * Creates a Kullback Leibler Divergence AbstractLoss AbstractLoss using {@link + * Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public KLDivergence(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public KLDivergence(Reduction reduction) { + super(null, reduction); } /** - * Creates a Kullback Leibler Divergence Loss + * Creates a Kullback Leibler Divergence AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public KLDivergence(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public KLDivergence(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.kullbackLeiblerDivergence(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index a11d582e527..02200c3a9e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -33,7 +34,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * LogCosh logcosh = new LogCosh(tf); - * Operand<TFloat32> result = logcosh.call(labels, predictions); + * Operand<TFloat32> result = logcosh.call(Ops tf, labels, predictions); * // produces 0.108 * * @@ -41,74 +42,71 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
      - *    Operand<TFloat32> result = logcosh.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = logcosh.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.087f
        * 
      * *

      Using SUM reduction type: * *

      - *    LogCosh logcosh = new LogCosh(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
      + *    LogCosh logcosh = new LogCosh(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = logcosh.call(tf, labels, predictions);
        *    // produces 0.217f
        * 
      * *

      Using NONE reduction type: * *

      - *    LogCosh logcosh = new LogCosh(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
      + *    LogCosh logcosh = new LogCosh(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = logcosh.call(tf, labels, predictions);
        *    // produces [0.217f, 0f]
        * 
      */ -public class LogCosh extends Loss { +public class LogCosh extends AbstractLoss { /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public LogCosh(Ops tf) { - this(tf, null, Reduction.AUTO); + public LogCosh() { + this(null, Reduction.AUTO); } /** - * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a LogCosh AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public LogCosh(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public LogCosh(String name) { + this(name, Reduction.AUTO); } /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public LogCosh(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public LogCosh(Reduction reduction) { + this(null, reduction); } /** - * Creates a LogCosh Loss + * Creates a LogCosh AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public LogCosh(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public LogCosh(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.logCosh(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.logCosh(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index cdd35d28aba..4dd5bce6cde 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -18,60 +18,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -public abstract class Loss { - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; - - protected final Ops tf; - protected final Reduction reduction; - - /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - protected Loss(Ops tf) { - this(tf, null, Reduction.AUTO); - } - - /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. - */ - protected Loss(Ops tf, String name) { - this(tf, name, Reduction.AUTO); - } - - /** - * Creates a Loss - * - * @param tf the TensorFlow Ops - * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. - * @param reduction Type of Reduction to apply to the loss. - */ - protected Loss(Ops tf, String name, Reduction reduction) { - this.tf = name != null ? tf.withSubScope(name) : tf.withSubScope(getClass().getSimpleName()); - this.reduction = reduction; - } - - /** - * Calculates the loss - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param The data type of the predictions and loss. - * @return the loss - */ - public Operand call( - Operand labels, Operand predictions) { - return call(labels, predictions, null); - } +/** Interface for loss calculation */ +@FunctionalInterface +public interface Loss { /** * Generates an Operand that calculates the loss. * + * @param tf the TensorFlow Ops * @param labels the truth values or labels * @param predictions the predictions * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is @@ -84,24 +38,6 @@ public Operand call( * @param The data type of the predictions, sampleWeights and loss. 
* @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the loss reduction - * - * @return the loss reduction - */ - public Reduction getReduction() { - return reduction; - } + Operand call( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 03a3cf70110..d85bdf3561a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsoluteError mae = new MeanAbsoluteError(tf); - * Operand<TFloat32> result = mae.call(labels, predictions); + * Operand<TFloat32> result = mae.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
      - *    Operand<TFloat32> result = mae.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = mae.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.25f
        * 
      * *

      Using SUM reduction type: * *

      - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = mae.call(labels, predictions);
      + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = mae.call(tf, labels, predictions);
        *    // produces 1.0f
        * 
      * *

      Using NONE reduction type: * *

      - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = mae.call(labels, predictions);
      + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = mae.call(tf, labels, predictions);
        *    // produces [0.5f, 0.5f]
        * 
      */ -public class MeanAbsoluteError extends Loss { +public class MeanAbsoluteError extends AbstractLoss { /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsoluteError(Ops tf) { - super(tf); + public MeanAbsoluteError() { + super(); } /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsoluteError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsoluteError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsoluteError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsoluteError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsoluteError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 6c5242df4f2..ed5c7d73e2f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf); - * Operand<TFloat32> result = mape.call(labels, predictions); + * Operand<TFloat32> result = mape.call(Ops tf, labels, predictions); * // produces 50f * * @@ -40,64 +41,62 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
      - *    Operand<TFloat32> result = mape.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = mape.call(tf, labels, predictions, sampleWeight);
        *    // produces 20f
        * 
      * *

      Using SUM reduction type: * *

      - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = mape.call(labels, predictions);
      + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = mape.call(tf, labels, predictions);
        *    // produces 100.0f
        * 
      * *

      Using NONE reduction type: * *

      - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = mape.call(labels, predictions);
      + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = mape.call(tf, labels, predictions);
        *    // produces [25f, 75f]
        * 
      */ -public class MeanAbsolutePercentageError extends Loss { +public class MeanAbsolutePercentageError extends AbstractLoss { /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsolutePercentageError(Ops tf) { - super(tf); + public MeanAbsolutePercentageError() { + super(); } /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsolutePercentageError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsolutePercentageError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsolutePercentageError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsolutePercentageError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsolutePercentageError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index f975db55c44..c6898e20f20 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredError mse = new MeanSquaredError(tf); - * Operand<TFloat32> result = mse.call(labels, predictions); + * Operand<TFloat32> result = mse.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
      - *    Operand<TFloat32> result = mse.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = mse.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.25f
        * 
      * *

      Using SUM reduction type: * *

      - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = mse.call(labels, predictions);
      + *    MeanSquaredError mse = new MeanSquaredError(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = mse.call(tf, labels, predictions);
        *    // produces 1.0f
        * 
      * *

      Using NONE reduction type: * *

      - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = mse.call(labels, predictions);
      + *    MeanSquaredError mse = new MeanSquaredError(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = mse.call(tf, labels, predictions);
        *    // produces [0.5f, 0.5f]
        * 
      */ -public class MeanSquaredError extends Loss { +public class MeanSquaredError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredError(Ops tf) { - super(tf); + public MeanSquaredError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 11b8e157e90..3d325a98a6a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf); - * Operand<TFloat32> result = msle.call(labels, predictions); + * Operand<TFloat32> result = msle.call(Ops tf, labels, predictions); * // produces 0.240f * * @@ -40,64 +41,61 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
      - *    Operand<TFloat32> result = msle.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = msle.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.120f
        * 
      * *

      Using SUM reduction type: * *

      - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = msle.call(labels, predictions);
      + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = msle.call(tf, labels, predictions);
        *    // produces 0.480f
        * 
      * *

      Using NONE reduction type: * *

      - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = msle.call(labels, predictions);
      + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.NONE);
      + *    Operand&lt;TFloat32&gt; result = msle.call(tf, labels, predictions);
        *    // produces [0.240f, 0.240f]
        * 
      */ -public class MeanSquaredLogarithmicError extends Loss { +public class MeanSquaredLogarithmicError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredLogarithmicError(Ops tf) { - super(tf); + public MeanSquaredLogarithmicError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredLogarithmicError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredLogarithmicError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredLogarithmicError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredLogarithmicError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 78324acf8a5..a6eb29b7109 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * Poisson poissonLoss = new Poisson(tf); - * Operand<TFloat32> result = poissonLoss.call(labels, predictions); + * Operand<TFloat32> result = poissonLoss.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,74 +41,71 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
      - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions, sampleWeight);
      + *    Operand&lt;TFloat32&gt; result = poissonLoss.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.4f
        * 
      * *

      Using SUM reduction type: * *

      - *    Poisson poissonLoss = new Poisson(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
      + *    Poisson poissonLoss = new Poisson(Reduction.SUM);
      + *    Operand&lt;TFloat32&gt; result = poissonLoss.call(tf, labels, predictions);
        *    // produces 0.999f
        * 
      * *

      Using NONE reduction type: * *

      - *    Poisson poissonLoss = new Poisson(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
      + *    Poisson poissonLoss = new Poisson(Reduction.NONE);
      + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions);
        *    // produces [0.999f, 0f]
        * 
      */ -public class Poisson extends Loss { +public class Poisson extends AbstractLoss { /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Poisson(Ops tf) { - this(tf, null, Reduction.AUTO); + public Poisson() { + this(null, Reduction.AUTO); } /** - * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a Poisson AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Poisson(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public Poisson(String name) { + this(name, Reduction.AUTO); } /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Poisson(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public Poisson(Reduction reduction) { + this(null, reduction); } /** - * Creates a Poisson Loss + * Creates a Poisson AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Poisson(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Poisson(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.poisson(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.poisson(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index 87ea43c6c3a..e40ec6d6ebb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.losses; /** - * Type of Loss Reduction + * Type of AbstractLoss Reduction * *

      {@link #AUTO} indicates that the reduction option will be determined by the usage context. For * almost all cases this defaults to {@link #SUM_OVER_BATCH_SIZE}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index d04cc67d5d9..291a91894b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -43,7 +44,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf); - * Operand<TFloat32> result = sparseCCE.call(labels, predictions); + * Operand<TFloat32> result = sparseCCE.call(Ops tf, labels, predictions); * // produces 1.177f * * @@ -51,27 +52,27 @@ * *

        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
      - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions, sampleWeight);
      + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions, sampleWeight);
        *    // produces 0.814f
        * 
      * *

      Using SUM reduction type: * *

      - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
      + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.SUM);
      + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
        *    // produces 2.354f
        * 
      * *

      Using NONE reduction type: * *

      - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
      + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.NONE);
      + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
        *    // produces [0.0513f, 2.303f]
        * 
      */ -public class SparseCategoricalCrossentropy extends Loss { +public class SparseCategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final int AXIS_DEFAULT = -1; @@ -80,24 +81,23 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ - public SparseCategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of this loss function */ - public SparseCategoricalCrossentropy(Ops tf, String name) { - this(tf, name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -107,8 +107,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name) { * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. 
*/ - public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** @@ -119,32 +119,32 @@ public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { * @param name the name of this loss function * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} and fromLogits={@link * #FROM_LOGITS_DEFAULT}. 
* * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -155,8 +155,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduction) { - this(tf, null, fromLogits, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits, Reduction reduction) { + this(null, fromLogits, reduction, AXIS_DEFAULT); } /** @@ -170,8 +170,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduc * and axis=1 corresponds to data format 'Channels First'. 
*/ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, Reduction reduction, int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, Reduction reduction, int axis) { + super(name, reduction); this.fromLogits = fromLogits; this.axis = axis; } @@ -199,23 +199,24 @@ public SparseCategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.sparseCategoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.sparseCategoricalCrossentropy(tf, labels, lPredictions, fromLogits, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index dadbdb3b95e..c804b463984 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import 
org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * SquaredHinge squaredHinge = new SquaredHinge(tf); - * Operand<TFloat32> result = squaredHinge.call(labels, predictions); + * Operand<TFloat32> result = squaredHinge.call(Ops tf, labels, predictions); * // produces 1.86f * * @@ -45,7 +46,7 @@ * *
        *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
      - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions,
      + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions,
        *                                                  sampleWeight);
        *    // produces 0.73f
        * 
      @@ -53,50 +54,46 @@ *

      Using SUM reduction type: * *

      - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.SUM);
      - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
      + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.SUM);
      + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
        *    // produces 3.72f
        * 
      * *

      Using NONE reduction type: * *

      - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.NONE);
      - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
      + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.NONE);
      + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
        *    // produces [1.46f, 2.26f]
        * 
      */ -public class SquaredHinge extends Loss { +public class SquaredHinge extends AbstractLoss { /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public SquaredHinge(Ops tf) { - super(tf); + public SquaredHinge() { + super(); } /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public SquaredHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Squared Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public SquaredHinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -123,19 +120,17 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? 
(Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.squaredHinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java new file mode 100644 index 00000000000..9534f6fe3ad --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.Reduction; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +public abstract class AbstractLoss implements Loss { + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + + protected final Reduction reduction; + private final String name; + + /** + * Creates a AbstractLoss using {@link Class#getSimpleName()} as the name and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + */ + protected AbstractLoss() { + this(null, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss using a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + * + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. + */ + protected AbstractLoss(String name) { + this(name, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss + * + * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. + * @param reduction Type of Reduction to apply to the loss. + */ + protected AbstractLoss(String name, Reduction reduction) { + this.name = name == null ? getClass().getSimpleName() : name; + this.reduction = reduction; + } + + /** + * Calculates the loss + * + * @param tf the TensorFlow Ops + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the predictions and loss. 
+ * @return the loss + */ + public Operand call( + Ops tf, Operand labels, Operand predictions) { + return call(tf, labels, predictions, null); + } + + /** + * Gets the loss reduction + * + * @return the loss reduction + */ + public Reduction getReduction() { + return reduction; + } + + /** + * Gets the name for this loss + * + * @return the name for this loss + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index bc5047d5855..69cb2ee0dfe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -40,26 +40,26 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of - * recall and precision values. The area under the ROC-curve is therefore computed using the height - * of the recall values by the false positive rate, while the area under the PR-curve is the - * computed using the height of the precision values by the recall. + *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives }, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the AUC. To discretize + * the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and + * precision values. The area under the ROC-curve is therefore computed using the height of the + * recall values by the false positive rate, while the area under the PR-curve is the computed using + * the height of the precision values by the recall. * - *

      This value is ultimately returned as {@code auc}, an idempotent operation that computes - * the area under a discretized curve of precision versus recall values (computed using the + *

      This value is ultimately returned as {@code auc}, an idempotent operation that computes the + * area under a discretized curve of precision versus recall values (computed using the * aforementioned variables). The {@code numThresholds} variable controls the degree of * discretization with larger numbers of thresholds more closely approximating the true AUC. The - * quality of the approximation may vary dramatically depending on {@code numThresholds}. The - * {@code thresholds} parameter can be used to manually specify thresholds which split the - * predictions more evenly. + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The {@code + * thresholds} parameter can be used to manually specify thresholds which split the predictions more + * evenly. * - *

      For best results, {@code predictions} should be distributed approximately uniformly in - * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor - * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code - * majoring} can help quantify the error in the approximation by providing lower or upper - * bound estimate of the AUC. + *

      For best results, {@code predictions} should be distributed approximately uniformly in the + * range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if + * this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code majoring} can + * help quantify the error in the approximation by providing lower or upper bound estimate of the + * AUC. * *

      Usage:
      * @@ -155,8 +155,8 @@ public class AUC extends Metric { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -180,8 +180,8 @@ public AUC(Ops tf, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -206,8 +206,8 @@ public AUC(Ops tf, String name, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code null} + * for labelWeights. 
* * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -233,8 +233,8 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -259,8 +259,8 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -314,8 +314,8 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for - * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. 
* * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -372,8 +372,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -400,8 +400,8 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, - * and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -428,8 +428,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, - * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for - * labelWeights. + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. 
Values @@ -453,8 +452,8 @@ public AUC( /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} - * for labelWeights. + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} for + * labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -487,8 +486,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code - * false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code false} + * for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -513,8 +512,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -560,16 +559,16 @@ public AUC( * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. This method - * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) - * below and a (1 + {@link #EPSILON}) above. + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) below and + * a (1 + {@link #EPSILON}) above. 
* @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to {@code false} for multi-class data. * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code - * multiLabel} is true, the weights are applied to the individual label AUCs when they - * are averaged to produce the multi-label AUC. When it's false, they are used to weight the + * multiLabel} is true, the weights are applied to the individual label AUCs when they are + * averaged to produce the multi-label AUC. When it's false, they are used to weight the * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -684,8 +683,8 @@ private Map> build(Shape shape) { } // Create metric variables - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(variableShape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(variableShape), type); if (truePositives == null) { truePositives = tf.withName(getTruePositivesName()).variable(zero); initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); @@ -715,8 +714,8 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to {@code }. 
If {@link #multiLabel} or if {@link #labelWeights} {@code != null - * }, then Cx must be a single dimension. + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null }, + * then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to * {@code }. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 516d6c91ba6..b8ec681cbfc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -29,12 +29,10 @@ * Metric that calculates how often predictions equals labels. * *

      This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

      If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

      If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 0e41699e165..a03677efd43 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -26,12 +26,10 @@ * Metric that calculates how often predictions matches binary labels. * *

      This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

      If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

      If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index dece2d1cd50..0cd90325e32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,18 +27,17 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

      You can provide {@code logits} of classes as {@code predictions}, since argmax of - * {@code logits} and probabilities are same. + *

      You can provide {@code logits} of classes as {@code predictions}, since argmax of {@code + * logits} and probabilities are same. * - *

      This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the frequency with which {@code predictions} matches {@code labels}. - * This frequency is ultimately returned as categorical accuracy: an idempotent operation that - * simply divides total by count. + *

      This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as categorical accuracy: an idempotent operation that simply divides total by + * count. * - *

      {@code predictions} and {@code labels} should be passed in as vectors of - * probabilities, rather than as labels. If necessary, use {@link - * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand - * {@code labels} as a vector. + *

      {@code predictions} and {@code labels} should be passed in as vectors of probabilities, rather + * than as labels. If necessary, use {@link org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, + * Operand, OneHot.Options...)} to expand {@code labels} as a vector. * *

      If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 58aa51f664c..4a32981aeeb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -29,8 +29,7 @@ * *

      This is the crossentropy metric class to be used when there are multiple label classes (2 or * more). The labels should be given as a one_hot representation. eg., When labels values are {@code - * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * }. + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] }. * * @param The data type for the metric result */ @@ -52,9 +51,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +72,12 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. 
{@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} - * corresponds to data format {@code channels_last}, and {@code - * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code - * channels_first}. + * corresponds to data format {@code channels_last}, and {@code axis={@link + * Losses#CHANNELS_FIRST}} corresponds to data format {@code channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 3db7fffc2e9..9f957ee6c17 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false negatives. * - *

      If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false negatives. + *

      If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false negatives. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalseNegatives(Ops tf, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 551529b6179..a3d585dea0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false positives. * - *

      If {@code sampleWeights} is given, calculates the sum of the weights of false positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false positives. + *

      If {@code sampleWeights} is given, calculates the sum of the weights of false positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false positives. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalsePositives(Ops tf, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 22baab3d6cb..04f4deb81cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -93,11 +93,15 @@ private void init() { Shape variableShape = Shape.of(numClasses, numClasses); if (totalConfusionMatrix == null) { - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); totalConfusionMatrix = - getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + getTF() + .withName(totalCMName) + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); initializer = - getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + getTF() + .assign( + totalConfusionMatrix, zeros.call(getTF(), getTF().constant(variableShape), type)); } } @@ -124,8 +128,8 @@ public Assign getInitializer() { * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. 
* @return the Operands that updates totalConfusionMatrix variable - * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, - * and if the predictions size is not equal to the labels size + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank {@code !=} + * labels rank, and if the predictions size is not equal to the labels size */ @Override public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index acf28f5b2cc..8d92b97ec5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -28,13 +28,12 @@ /** * Computes the mean relative error by normalizing with the given values. * - *

      This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is - * ultimately returned as mean relative error: an idempotent operation that simply divides total by - * count. + *

      This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the mean relative error. This is weighted by {@code sampleWeight}, and it is ultimately + * returned as mean relative error: an idempotent operation that simply divides total by count. * - *

      If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} - * of 0 to mask values. + *

      If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} of 0 + * to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index d88d7a4c1b4..583d9b2dde7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -85,8 +85,8 @@ public MeanTensor(Ops tf, String name, long seed, Class type) { private boolean init(Shape shape) { if (!initialized) { this.shape = shape; - Zeros zeros = new Zeros<>(getTF()); - Operand zero = zeros.call(getTF().constant(shape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(getTF(), getTF().constant(shape), type); if (total == null) { total = getTF().withName(totalName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 3812e799b75..f81b32e8d76 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -36,22 +36,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

      The metric creates two local variables, {@code truePositives} and {@code falsePositives - * } that are used to compute the precision. This value is ultimately returned as precision, - * an idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falsePositives}. + *

      The metric creates two local variables, {@code truePositives} and {@code falsePositives } that + * are used to compute the precision. This value is ultimately returned as precision, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falsePositives}. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

      If {@code topK} is set, the metric calculates precision as how often on average a class - * among the top-k classes with the highest predicted values of a batch entry is correct and can be - * found in the label for that entry. + *

      If {@code topK} is set, the metric calculates precision as how often on average a class among + * the top-k classes with the highest predicted values of a batch entry is correct and can be found + * in the label for that entry. * *

      If {@code classId} is specified, the metric calculates precision by considering only the - * entries in the batch for which {@code classId} is above the {@code thresholds} and/or - * in the top-k highest predictions, and computing the fraction of them for which {@code classId - * } is indeed a correct label. + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or in the + * top-k highest predictions, and computing the fraction of them for which {@code classId } is + * indeed a correct label. * * @param The data type for the metric result */ @@ -103,10 +103,9 @@ public Precision(Ops tf, String name, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -138,10 +137,9 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). 
One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -172,10 +170,9 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -216,10 +213,9 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. 
+ * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -280,17 +276,15 @@ public Precision( /** Initializes the variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { this.truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); } if (this.falsePositives == null) { - this.falsePositives = - tf.withName(falsePositivesName) - .variable(zero); + this.falsePositives = tf.withName(falsePositivesName).variable(zero); initializers.add(tf.assign(falsePositives, zero)); } } @@ -340,11 +334,12 @@ public List updateStateList( public Operand result() { Ops tf = getTF(); Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - return thresholds.length == 1 - ? tf.reshape(tf.slice( - result, - tf.expandDims(tf.constant(0), tf.constant(0)), - tf.expandDims(tf.constant(1), tf.constant(0))), + return thresholds.length == 1 + ? 
tf.reshape( + tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))), tf.constant(Shape.scalar())) : result; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 5f5f9b47a10..0bb49378f5b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -29,8 +29,8 @@ * falseNegatives that are used to compute the precision at the given recall. The threshold for the * given recall value is computed and used to evaluate the corresponding precision. * - *

      If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

      If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ @@ -115,8 +115,7 @@ public PrecisionAtRecall( public Operand result() { Ops tf = getTF(); - Operand div = - tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand div = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 3886ec050b0..2780add994f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -36,20 +36,20 @@ /** * Computes the recall of the predictions with respect to the labels. * - *

      This metric creates two local variables, {@code truePositives} and {@code falseNegatives - * }, that are used to compute the recall. This value is ultimately returned as recall, an - * idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falseNegatives}. + *

      This metric creates two local variables, {@code truePositives} and {@code falseNegatives}, + * that are used to compute the recall. This value is ultimately returned as recall, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falseNegatives}. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

      If {@code topK} is set, the metric calculates recall as how often on average a class - * among the labels of a batch entry is in the top-k predictions. + *

      If {@code topK} is set, the metric calculates recall as how often on average a class among the + * labels of a batch entry is in the top-k predictions. * - *

      If {@code classId} is specified, the metric calculates recall by considering only the - * entries in the batch for which {@code classId} is in the label, and computing the fraction - * of them for which {@code classId} is above the threshold and/or in the top-k predictions. + *

      If {@code classId} is specified, the metric calculates recall by considering only the entries + * in the batch for which {@code classId} is in the label, and computing the fraction of them for + * which {@code classId} is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ @@ -305,8 +305,8 @@ public Recall( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { truePositives = tf.withName(truePositivesName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index a3fc2f77b7f..e54def48fce 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -34,8 +34,8 @@ * falseNegatives that are used to compute the recall at the given precision. The threshold for the * given precision value is computed and used to evaluate the corresponding recall. * - *

      If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

      If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 3886428425b..0d140eb96b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,8 +27,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between {@code labels} and {@code predictions} - * . + * Computes root mean squared error metric between {@code labels} and {@code predictions}. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 29c0504b823..23a529ae1bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -25,19 +25,19 @@ /** * Computes best sensitivity where sensitivity is >= specified value. * - *

      {@code Sensitivity} measures the proportion of actual positives that are correctly - * identified as such {@code (tp / (tp + fn))}. + *

      {@code Sensitivity} measures the proportion of actual positives that are correctly identified + * as such {@code (tp / (tp + fn))}. * - *

      {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

      {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * sensitivity at the given specificity. The threshold for the given specificity value is computed - * and used to evaluate the corresponding sensitivity. + *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the sensitivity at the + * given specificity. The threshold for the given specificity value is computed and used to evaluate + * the corresponding sensitivity. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 5294f798044..1d017ddf8fb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -35,9 +35,9 @@ * probabilities are same. * *

      This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides - * `total` by `count`. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as `sparse categorical accuracy`: an idempotent operation that simply divides `total` by + * `count`. * *

      If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 2cb7e54eba0..95d46c8fd06 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -24,19 +24,19 @@ /** * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} - * measures the proportion of actual positives that are correctly identified as such {@code - * (tp / (tp + fn))}. + * measures the proportion of actual positives that are correctly identified as such {@code (tp / + * (tp + fn))}. * - *

      {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

      {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * specificity at the given sensitivity. The threshold for the given sensitivity value is computed - * and used to evaluate the corresponding specificity. + *

      This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the specificity at the + * given sensitivity. The threshold for the given sensitivity value is computed and used to evaluate + * the corresponding specificity. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index 637ca6cdd05..bcb1d7b9a36 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -21,11 +21,11 @@ /** * Computes the (weighted) sum of the given values. * - *

      For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the - * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} + *

      For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the weights were + * specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} * - *

      This metric creates one variable, {@code total}, that is used to compute the sum of - * values. This is ultimately returned as sum. + *

      This metric creates one variable, {@code total}, that is used to compute the sum of values. + * This is ultimately returned as sum. * *

      If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index 0146552433f..b6e50c3295a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -34,8 +34,8 @@ public class TopKCategoricalAccuracy extends MeanMetricWrappe private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of - * top elements to look at for computing accuracy. + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of top + * elements to look at for computing accuracy. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index 5c65f8c469f..fd6b95df6d2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true negatives. * - *

      If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true negatives. + *

      If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true negatives. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TrueNegatives(Ops tf, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index f0dd8c42de5..90fe9142014 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true positives. * - *

      If {@code sampleWeights} is given, calculates the sum of the weights of true positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true positives. + *

      If {@code sampleWeights} is given, calculates the sum of the weights of true positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true positives. * - *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

      If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TruePositives(Ops tf, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 88597cf85ec..b031d80d0ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -67,10 +67,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -91,10 +90,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -118,12 +116,13 @@ public ConfusionMatrixConditionCount( private void init() { Shape variableShape = Shape.of(this.thresholds.length); - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); accumulator = getTF() .withName(getAccumulatorName()) - .variable(zeros.call(getTF().constant(variableShape), type)); - initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); + initializer = + getTF().assign(accumulator, zeros.call(getTF(), getTF().constant(variableShape), type)); } /** @@ -189,7 +188,10 @@ public float[] getThresholds() { return this.thresholds; } - /** @return the accumulatorName */ + /** + * Gets the accumulatorName + * @return the accumulatorName + */ public String getAccumulatorName() { return accumulatorName; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index f89047e457d..76c21aebefc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * Interface for Metrics that wrap Loss functions. + * Interface for Metrics that wrap AbstractLoss functions. * * @param The data type of the predictions. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 37bdd5849ae..ec103197709 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,9 +29,9 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

      The loss function calculates the loss between the {@code labels} and {@code predictions - * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the - * loss over many iterations or epochs + *

      The loss function calculates the loss between the {@code labels} and {@code predictions} then + * passes this loss to the {@link Mean} metric to calculate the weighted mean of the loss over many + * iterations or epochs * * @param The data type for the metric result */ @@ -63,7 +63,7 @@ public LossMetric getLoss() { } /** - * Sets the Loss function for this wrapper. + * Sets the AbstractLoss function for this wrapper. * * @param loss the loss function. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 40336233d21..51b8836ec83 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -59,8 +59,7 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values - * } + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values} * *

      In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -69,8 +68,8 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} - * can be broadcast to {@code values} + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} can be + * broadcast to {@code values} * @param the type of Operand * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an * incorrect shape that prohibit broadcasting to {@code values} @@ -114,10 +113,7 @@ public static Op assertBroadcastable( throw new NotBroadcastableException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCAST_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -307,24 +303,24 @@ public static List assertShapes( *

      For estimation of these metrics over a stream of data, the function creates an `update_op` * operation that updates the given variables. * - *

      {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code - * sampleWeight} is then broadcast to the shape of {@code predictions}. + *

      {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are aligned by {@link + * LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code sampleWeight} is then + * broadcast to the shape of {@code predictions}. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then - * the variable shapes are (T). + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then the + * variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when - * topK is set + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when topK + * is set * @param topK optional, indicates that only the top k predictions should be considered. Applied * before possibly slicing by {@code classIndex}. * @param classIndex optional, limits the prediction and labels to the specified class. This is an @@ -338,14 +334,14 @@ public static List assertShapes( * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. 
when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex - * } is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. - * Must be null if {@code multiLabel}. + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex } + * is provided, then the final dimension of Dx is 1. These weights will be broadcast across + * the 0th dimension (the examples dimension) of {@code predictions}. May be null. Must be + * null if {@code multiLabel}. * @param the data type for the variables - * @throws IllegalArgumentException If {@code predictions} and {@code labels} have - * mismatched shapes, or if {@code sampleWeight} is not null and its shape - * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. + * @throws IllegalArgumentException If {@code predictions} and {@code labels} have mismatched + * shapes, or if {@code sampleWeight} is not null and its shape doesn't match {@code + * predictions}, or if {@code multiLabel && labelWeights != null}.. * @return an op to update the given confusion matrix variables. 
*/ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -439,11 +435,13 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), if (classIndex != null) { // Slice to new shapes (N, Dx) - tLabels = tf.squeeze(tf.gather(tLabels, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tLabels = + tf.squeeze( + tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); - tPredictions = tf.squeeze(tf.gather(tPredictions, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tPredictions = + tf.squeeze( + tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); @@ -693,8 +691,7 @@ private static Operand filterTopK(Ops tf, Operand x, i // alias for mean /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -706,8 +703,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -725,10 +722,9 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the type of the operand * @return the mean of elements of {@code x}. */ @@ -742,10 +738,9 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the data type of the Operand * @return the mean of elements of {@code x}. */ @@ -783,12 +778,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( *

      For example: * *

      {@code
      -   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
      -   *          [[0 0 0 0 0]
      -   *           [0 0 1 0 0]
      -   *           [0 0 1 0 0]
      -   *           [0 0 0 0 0]
      -   *           [0 0 0 0 1]]
      +   * confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
      +   *      [[0 0 0 0 0]
      +   *       [0 0 1 0 0]
      +   *       [0 0 1 0 0]
      +   *       [0 0 0 0 0]
      +   *       [0 0 0 0 1]]
          * }
      * * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 @@ -802,12 +797,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands - * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} - * representing the confusion matrix, where {@code n} is the number of possible labels in - * the classification task. - * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do - * not have compatible shapes, or if {@code weights} is not{@code null} and its - * shape is not compatible with {@code predictions}. + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} representing the + * confusion matrix, where {@code n} is the number of possible labels in the classification + * task. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do not have + * compatible shapes, or if {@code weights} is not{@code null} and its shape is not compatible + * with {@code predictions}. */ // TODO should this be moved to FramnworkOps under math. 
public static Operand confusionMatrix( @@ -883,8 +878,7 @@ public static Operand confusionMatrix( } /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -895,8 +889,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -913,10 +907,9 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { @@ -929,10 +922,9 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 60a6c1ea3df..e47ea4ea8e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -87,9 +87,9 @@ protected SensitivitySpecificityBase( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); + Zeros zeros = new Zeros<>(); Shape varShape = Shape.of(numThresholds); - Operand zero = zeros.call(tf.constant(varShape), type); + Operand zero = zeros.call(tf, tf.constant(varShape), type); if (this.getTruePositives() == null) { @@ -228,8 +228,6 @@ public int getNumThresholds() { return numThresholds; } - - /** * Gets the thresholds * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 68157632557..0553b1edac7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ 
-26,8 +26,8 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with - * {@code aMinusB} set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with {@code + * aMinusB} set to true. * *

      All but the last dimension of {@code a} and {@code b} must match * @@ -35,8 +35,8 @@ public class SetsOps { * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference(Ops tf, Operand a, Operand b) { @@ -53,8 +53,8 @@ public static Operand difference(Ops tf, Operand a, Op * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference( @@ -69,8 +69,8 @@ public static Operand difference( * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. 
*/ public static Operand union(Ops tf, Operand a, Operand b) { @@ -84,8 +84,8 @@ public static Operand union(Ops tf, Operand a, Operand * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand intersection(Ops tf, Operand a, Operand b) { @@ -100,8 +100,8 @@ public static Operand intersection(Ops tf, Operand a, * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the same. Elements along the last dimension contain the results of the set * operation. */ public static Operand setOperation( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java index d28185ae041..7c3fda07ea9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -21,35 +21,72 @@ import java.util.Arrays; import java.util.List; +/** + * A class that represents a Symbolic shape. + * + *

      A Symbolic shape uses symbols to identify the relationship of the shape of an operand to + * underlying values that are not known until compute time. For example, "N" represents the number of + * examples, while "L" represents the number of labels. When the values later become known, the + * shape of the operand must conform to these symbolic values. + * + * @param The data type for the Operand. + */ public class SymbolicShape { private Operand operand; private List symbols = new ArrayList<>(); + /** + * Creates a SymbolicShape + * + * @param operand the Operand that needs to conform to the shape + * @param symbols the symbolic value for each dimension of the shape. + */ public SymbolicShape(Operand operand, String... symbols) { this.operand = operand; this.symbols.addAll(Arrays.asList(symbols)); } - /** @return the operand */ + /** + * Gets the operand + * + * @return the operand + */ public Operand getOperand() { return operand; } - /** @param operand the operand to set */ + /** + * Sets the operand + * + * @param operand the operand to set + */ public void setOperand(Operand operand) { this.operand = operand; } - /** @return the symbols */ + /** + * Gets the symbols associated with each dimension of the shape + * + * @return the symbols associated with each dimension of the shape + */ public List getSymbols() { return symbols; } - /** @param symbols the symbols to set */ + /** + * Sets the symbols associated with each dimension of the shape + * + * @param symbols the symbols associated with each dimension of the shape + */ public void setSymbols(List symbols) { this.symbols = symbols; } + /** + * Gets the rank associated with this Symbolic Shape + * + * @return the rank associated with this Symbolic Shape + */ public int rank() { return this.symbols.size(); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 6583465da2e..18b11700380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -32,8 +32,8 @@ /** * Weight broadcasting operations. * - *

      In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we support limited weight broadcasting. This file includes - * operations for those broadcasting rules. + *

      In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we + * support limited weight broadcasting. This file includes operations for those broadcasting rules. */ public class WeightsBroadcastOps { @@ -46,10 +46,11 @@ public class WeightsBroadcastOps { * @param tf the TensorFlow Ops * @param weights the weights Operand * @param values Operand of values to which weights are applied. - * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has incorrect shape. {@link NoOp} if - * static checks determine {@code weights} has correct shape. + * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has + * incorrect shape. {@link NoOp} if static checks determine {@code weights} has correct shape. * @param the type of weights and values - * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape. + * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect + * shape. */ public static Op assertBroadcastable( Ops tf, Operand weights, Operand values) { @@ -81,14 +82,12 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (weightsShapeStatic.size(i) != 1 && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + if (weightsShapeStatic.size(i) != 1 + && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %s. 
values.shape=%s weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCASTABLE_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -105,12 +104,12 @@ public static Op assertBroadcastable( tf.constant("values.shape="), valuesShape, tf.constant("isScalar="), - isScalar); + isScalar); Operand isValidShape = tf.select( - isScalar, - isScalar, + isScalar, + isScalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); return tf.assertThat(isValidShape, data); @@ -140,7 +139,8 @@ private static Operand hasValidNonscalarShape( } /** - * Checks that each dimension of the two shapes are the same size, or that the weight dimension size is 1. + * Checks that each dimension of the two shapes are the same size, or that the weight dimension + * size is 1. * * @param tf the TensorFlow Ops * @param weightsShape the shape of the weights @@ -152,7 +152,8 @@ private static Operand hasValidDims( tf = tf.withSubScope("hasInvalidDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); - Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); @@ -164,8 +165,7 @@ private static Operand hasValidDims( * Broadcast {@code weights} to the same shape as {@code values}. * *

      This returns a version of {@code weights} following the same broadcast rules as {@code - * mul(weights, - * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * mul(weights, values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} * When computing a weighted average, use this function to broadcast {@code weights} before * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java new file mode 100644 index 00000000000..25535292db3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.framework.losses.impl.AbstractLoss; + +/** + * Base class for Regularizers + * + *

      Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + */ +public abstract class AbstractRegularizer implements Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final String name; + + /** Creates an AbstractRegularizer, using {@link Class#getSimpleName()} for the name */ + protected AbstractRegularizer() { + this(null); + } + /** + * Creates an AbstractRegularizer + * + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. + */ + protected AbstractRegularizer(String name) { + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this AbstractRegularizer as an AbstractLoss. This is a convenience to regularize a + * loss. Only sampleWeights are applied to the regularizer. + * + * @return this AbstractRegularizer as an AbstractLoss + */ + public AbstractLoss asLoss() { + return new RegularizerLoss(this); + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 7c8c2a1360a..4b7aa1af620 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -14,8 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) * Regression, regularization penalty.
@@ -24,24 +22,43 @@ */ public class L1 extends L1L2 { + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L1() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L1 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L1(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L1 regularization penalty and a name based on the class + * name. + * + * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L1(float l1) { + this(null, l1); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l1 the L1 regularization penalty * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf, float l1) { - super(tf, l1, 0f); + public L1(String name, float l1) { + super(name, l1, 0f); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 29e411f9897..6dfaf3f0d47 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -19,6 +19,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A regularizer that applies both L1 and L2 regularization penalties. * @@ -29,33 +31,39 @@ *

      The L2 regularization penalty is computed as * *

      loss = l2 * reduceSum(square(x))
      - * */ -public class L1L2 extends Regularizer { +public class L1L2 extends AbstractRegularizer { private final float l1; private final float l2; + /** Creates an L1L2 regularizer with l1 and l2 set to {@link #DEFAULT_REGULARIZATION_PENALTY} */ + public L1L2() { + this(DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + /** - * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty + * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor. + * @param l2 L2 regularization factor. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * or {@link Float#isInfinite} */ - public L1L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + public L1L2(float l1, float l2) { + this(null, l1, l2); } /** * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0.
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, float l1, float l2) { - super(tf); + public L1L2(String name, float l1, float l2) { + super(name); if (Float.isNaN(l1) || Float.isInfinite(l1)) { throw new IllegalArgumentException( String.format( @@ -73,25 +81,23 @@ public L1L2(Ops tf, float l1, float l2) { this.l2 = l2; } - /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand input) { if (this.getL1() == 0f && this.getL2() == 0f) { - return tf.dtypes.cast(tf.constant(0), input.type()); + return cast(tf, tf.constant(0), input.type()); } - Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + Operand regularization = cast(tf, tf.constant(0), input.type()); if (this.getL1() != 0.f) { - Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand l1Op = cast(tf, tf.constant(this.getL1()), input.type()); Operand abs = tf.math.abs(input); Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); } if (this.getL2() != 0.f) { - Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand l2Op = cast(tf, tf.constant(this.getL2()), input.type()); Operand sqr = tf.math.square(input); Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 7b8f5b28a70..9092b80b08f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -14,8 +14,6 @@ 
=======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * @@ -23,24 +21,43 @@ */ public class L2 extends L1L2 { + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L2() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L2 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L2(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L2 regularization penalty and a name based on the class + * name. + * + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ - public L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L2(float l2) { + this(null, l2); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l2 the L2 regularization penalty * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
*/ - public L2(Ops tf, float l2) { - super(tf, 0f, l2); + public L2(String name, float l2) { + super(name, 0f, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 5d9ff0e3e10..085f06e115c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,77 +15,18 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Base class for Regularizers - * - *

      Regularizers allow you to apply penalties on layer parameters or layer activity during - * optimization. These penalties are summed into the loss function that the network optimizes. - */ -public abstract class Regularizer { - - public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; - - private final Ops tf; - private final String name; - - /** - * Creates a Regularizer, using {@link Class#getSimpleName()} for the name - * - * @param tf the TensorFlow ops. - */ - protected Regularizer(Ops tf) { - this(tf, null); - } - /** - * Creates a Regularizer - * - * @param tf the TensorFlow ops. - * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the - * name. - */ - protected Regularizer(Ops tf, String name) { - this.tf = tf; - this.name = name == null ? this.getClass().getSimpleName() : name; - } - - /** - * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only - * sampleWeights are applied to the regularizer. - * - * @return this Regularizer as a Loss - */ - public Loss asLoss() { - return new RegularizerLoss(this.tf, this); - } +public interface Regularizer { /** * Computes a regularization penalty from an input. 
* + * @param tf the TensorFlow Ops * @param input the weighted input * @return the result of computing the regularization penalty * @param the data type of the input and result */ - public abstract Operand call(Operand input); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the name for this regularizer - * - * @return the name for this regularizer - */ - public String getName() { - return name; - } + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index 582cd038f8f..11c7ee492e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -15,50 +15,49 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A Regularizer call wrapped as a Loss instance + * A AbstractRegularizer call wrapped as a AbstractLoss instance * *

      This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. */ -class RegularizerLoss extends Loss { +class RegularizerLoss extends AbstractLoss { - private final Regularizer regularizer; + private final AbstractRegularizer regularizer; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates an AbstractLoss using {@link Class#getSimpleName()} as the name and an AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, Regularizer regularizer) { - this(tf, null, regularizer); + public RegularizerLoss(AbstractRegularizer regularizer) { + this(null, regularizer); } /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates an AbstractLoss using an AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. 
* @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { - super(tf, name); + public RegularizerLoss(String name, AbstractRegularizer regularizer) { + super(name); this.regularizer = regularizer; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } - return regularizer.call(sampleWeights); + return regularizer.call(tf, sampleWeights); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java index 914b94dfada..9f3fa75e95d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -14,36 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - -/** @author Jim Clarke */ public class ELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of ELU call method */ @Test public void testCallFloat() { @@ -52,8 +33,8 @@ public void testCallFloat() 
{ for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -66,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -80,8 +61,8 @@ public void testAlpha() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf, 2.0f); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(2.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java index 1157c582168..f82c19987d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static 
org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class ExponentialTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ExponentialTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Exponential call method. */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -78,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java index 35f57c47f66..0e32201c3e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import 
org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class HardSigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public HardSigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of HardSigmoid call method. */ @Test public void testCallFloat() { @@ -51,8 +33,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -65,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java index 7974035c680..817940688e8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class LinearTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LinearTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Linear call method. */ @Test public void testCallInt() { @@ -48,8 +34,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -62,8 +48,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -76,8 +62,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index a0aa2c4b453..94f803d6b1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -14,30 +14,20 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; /** @author Jim Clarke */ public class ReLUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ReLUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of ReLU call method */ @Test public void testCallFloat() { @@ -46,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -60,8 +50,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = 
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -74,8 +64,8 @@ public void testCallLong() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,9 +78,9 @@ public void testCallFloat16() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU<>(); Operand result = - instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); + instance.call(tf, tf.dtypes.cast(tf.constant(input), TFloat16.class)); session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); } } @@ -103,8 +93,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -112,12 +102,12 @@ public void testCallDouble() { @Test public void testAlpha() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-5. 
, -2.5, 0., 5., 10.}; + double[] expected = {-5., -2.5, 0., 5., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -129,8 +119,8 @@ public void testMaxValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -138,12 +128,12 @@ public void testMaxValue() { @Test public void testThreshold() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-0., -0., 0., 0., 10.}; + double[] expected = {-0., -0., 0., 0., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index 8bad6f1f066..ef4644df18e 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of SELU call method */ @Test public void testCallFloat() { @@ -53,8 +35,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +53,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java index 9dca622c3ec..0c59eeaba6e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -14,34 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Sigmoid call method */ @Test public void testCallFloat() { @@ -59,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -77,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java index 05ec3a4f716..aeb971905a2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -14,35 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SoftmaxTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftmaxTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Softmax method, of class Activations. 
*/ @Test public void testSoftmaxOpsOperandFloat() { @@ -54,8 +37,8 @@ public void testSoftmaxOpsOperandFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -71,8 +54,8 @@ public void testSoftmaxOpsOperandDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,8 +71,8 @@ public void testSoftmaxOpsOperandDoubleNegative() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -99,14 +82,14 @@ public void testSoftmaxOpsOperandDoubleNegative() { public void testSoftmax1D() { double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; double[] expected = { - 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, - 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 + 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, + 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = 
instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -116,14 +99,14 @@ public void testSoftmax1D() { public void testSoftmax3D() { double[][][] input = {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}; double[][][] expected = { - {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, - {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} + {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, + {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java index a17f2650d62..e896807d9f7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftplusTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftplusTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static 
void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softplus call method */ @Test public void testCallFloat() { @@ -50,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -68,8 +54,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java index 43591ab4761..2f9a17caf59 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftsignTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftsignTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - 
public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softsign call method */ @Test public void testCallFloat() { @@ -48,8 +34,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java index 7576789320b..8dabfaf379a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SwishTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SwishTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Swish call method */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -83,8 +65,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java index 5162e141c44..3988ec55bb3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -25,20 +25,6 @@ public class TanhTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, 
TestSession.Mode.GRAPH}; - public TanhTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Tanh call method. */ @Test public void testCallFloat() { @@ -52,8 +38,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index 1f80388e88f..259d6a963b5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -35,8 +35,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MaxNorm instance = new MaxNorm(tf, testValues[i.get()]); - Operand result = instance.call(weights); + MaxNorm instance = new MaxNorm(testValues[i.get()]); + Operand result = instance.call(tf, weights); session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } } @@ -47,13 +47,13 @@ public void testCall1() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm(tf, 2.0); + MaxNorm instance = new MaxNorm(2.0); Operand weights = tf.constant( new float[][] { {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, }); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float[] expected = { 0, 1, 2, 1.1547005f, 0, 0, 0, 1.1547005f, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8c2c3a54ff9..8b4c4007096 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -39,8 +39,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); - Operand result = instance.call(weights); + MinMaxNorm instance = new MinMaxNorm(testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(tf, weights); if (tfMode == TestSession.Mode.EAGER) evaluate(session, result.asTensor(), testValues[i.get()]); else diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java index 6a6fdc13536..1a24c188860 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -17,8 +17,8 @@ public void testTFloat32() { Ops tf = session.getTF(); float[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance 
= new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); float[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } @@ -31,8 +31,8 @@ public void testTFloat64() { Ops tf = session.getTF(); final double[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance = new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); double[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java index 6437ebcd760..9c784b7f31e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -28,8 +28,8 @@ public void testTFloat32() { }; Operand weights = tf.constant(array); - UnitNorm instance = new UnitNorm(tf, 1); - Operand result = instance.call(weights); + UnitNorm instance = new UnitNorm(1); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } @@ -50,9 +50,9 @@ public void testCallTFloat64() { {{0.72920675, 0.40984813, 0.55712338}, {0.68429305, 0.91215323, 0.83042956}}, {{0.97694125, 0.99972269, 0.13576831}, {0.21350717, 0.02353181, 0.99074035}} }; - UnitNorm instance = new UnitNorm(tf, 1); + UnitNorm instance = new UnitNorm(1); Operand weights = tf.constant(array); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 4e81e0620e6..9291e5f83ef 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class ConstantTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ConstantTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallUInt() { @@ -51,8 +43,9 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -67,8 +60,9 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -83,8 +77,9 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xffL); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Constant instance = new Constant<>(0xffL); + + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -97,8 +92,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 12.F); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Constant instance = new Constant<>(12.F); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -112,8 +108,9 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand 
operand = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -129,8 +126,9 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.class); + Constant instance = new Constant<>(22); + + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -145,8 +143,9 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Boolean[] expected = {true, true, true, true}; - Constant instance = new Constant<>(tf, true); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Constant instance = new Constant<>(true); + + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -158,9 +157,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index e9769806928..166011c3b64 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -14,7 +14,7 @@ 
=======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class GlorotTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public GlorotTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Glorot. */ @Test public void testCallNormalFloat() { @@ -51,9 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,8 +54,9 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,8 +69,9 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new 
Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -97,8 +85,9 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -109,9 +98,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -122,9 +112,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), 
TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -135,10 +126,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = - new Glorot<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 8953fa3005e..7b183358f85 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class HeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; int counter; - public HeTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class He. 
*/ @Test public void testCallNormalFloat() { @@ -51,8 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +53,9 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +68,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +84,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 
+97,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +111,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +125,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index 6eee5473937..3f5c6cdb363 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -14,37 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Identity initializer */ public class IdentityTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public IdentityTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallFloat() { @@ -64,8 +46,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -90,8 +72,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -103,9 +85,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Identity instance = new Identity<>(tf, 2.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 336850a5549..8858bac13dd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package 
org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class LeCunTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LeCunTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class LeCun. */ @Test public void testCallNormalFloat() { @@ -51,8 +37,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +52,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +66,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, 
SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +81,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +93,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +106,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, 
tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +119,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index 053ba5dd7ff..4872ce7ad8e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class OnesTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OnesTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Ones. */ @Test public void testCallUInt() { @@ -51,8 +43,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -65,8 +57,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -79,8 +71,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -93,8 +85,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); 
session.evaluate(expected, operand); } } @@ -108,8 +100,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -125,8 +117,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - instance.call(tf.constant(shape), TString.class); + Ones instance = new Ones<>(); + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -140,8 +132,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -153,9 +145,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.ones(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + 
Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index 22b89d9177c..c933e669dfd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -14,17 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Orthogonal initializer */ public class OrthogonalTest { @@ -33,20 +29,6 @@ public class OrthogonalTest { private static final double GAIN_VALUE = 1.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OrthogonalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Orthogonal. 
*/ @Test public void testCallFloat() { @@ -156,8 +138,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -271,8 +253,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -284,9 +266,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 3b2b3bdb243..dada058af42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class RandomNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomNormal. */ @Test public void testCalltestSoftmaxFloat() { @@ -52,9 +38,8 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } 
} @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index 23e26083a9b..1a1b3f755b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -31,20 +31,6 @@ public class RandomUniformTest { private static final double MAX_VALUE = 10.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomUniformTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomUniform. 
*/ @Test public void testCallInt() { @@ -53,9 +39,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -84,9 +68,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -98,10 +81,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + 
Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 96bf915e199..6ea19fde349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class TruncatedNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public TruncatedNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class TruncatedNormal. 
*/ @Test public void testCallFloat() { @@ -52,9 +38,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 159affb07e2..56aa95ecf73 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -28,20 +28,6 @@ public class VarianceScalingTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public VarianceScalingTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class VarianceScaling. */ @Test public void testCallFloat1FanInTruncatedNormal() { @@ -52,12 +38,11 @@ public void testCallFloat1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -73,12 +58,11 @@ public void testCallDouble1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -93,12 +77,8 @@ public void testCallFloat1FanInNormal() { Shape shape = 
Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -114,12 +94,8 @@ public void testCalltestSoftmaxDouble1FanInNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -134,8 +110,8 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -151,8 +127,8 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -166,9 +142,9 @@ public void 
testReproducible1() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -182,13 +158,9 @@ public void testReproducible2() { VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -202,13 +174,12 @@ public void testReproducible3() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_OUT, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -222,9 +193,9 @@ public void testReproducible4() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - 
Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java index 21bad6ff360..772baee1b61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java @@ -14,32 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; /** Test the Zeros initializer */ public class ZerosTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ZerosTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Zeros. 
*/ @Test public void testCallUInt() { @@ -48,8 +40,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -62,8 +54,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -76,8 +68,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -90,8 +82,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -105,8 +97,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand 
operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -119,8 +111,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TString.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TString.class); session.evaluateString(operand, String::isEmpty); } } @@ -134,8 +126,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -147,9 +139,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.zeros(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index d2128b80839..0b662414e8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -32,11 +32,12 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -48,9 +49,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new BinaryCrossentropy(tf, true); + instance = new BinaryCrossentropy(true); - loss = instance.call(yTrue, logits); + loss = instance.call(tf, yTrue, logits); testSession.evaluate(expected, loss); } } @@ -67,7 +68,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; float[] predArray = {2f, 1f, -1f, 0f}; Operand yTrue = @@ -75,7 +77,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -87,12 +89,13 @@ public void testUnweighted() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 3.83331f; testSession.evaluate(expected, loss); @@ -105,8 +108,9 @@ public void testUnweighted() { Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits); expected = 33.33333f; testSession.evaluate(expected, loss); } @@ -118,13 +122,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 8.816612f; testSession.evaluate(expected, loss); @@ -137,8 +142,9 @@ public void testScalarWeighted() { Operand yTrue1 = 
tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight); expected = 76.66667f; testSession.evaluate(expected, loss); } @@ -149,7 +155,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; float[] sampleWeightArray = {1.2f, 3.4f}; @@ -157,7 +164,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.59997f; testSession.evaluate(expected, loss); @@ -172,8 +179,9 @@ public void testSampleWeighted() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight1 = tf.constant(sampleWeightArray1); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight1); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight1); expected = 100f; testSession.evaluate(expected, loss); } @@ -196,8 +204,9 @@ public void testNoReduction() { tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); BinaryCrossentropy instance = new BinaryCrossentropy( - tf, true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); - Operand loss = 
instance.call(yTrue, logits); + true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.f, 66.666664f}; testSession.evaluate(expected, loss); } @@ -215,8 +224,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - BinaryCrossentropy instance = new BinaryCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + BinaryCrossentropy instance = new BinaryCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = (100.0f + 50.0f * labelSmoothing) / 3.0f; testSession.evaluate(expected, loss); } catch (Exception expected) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 13b287de3cd..3f6453b756a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -48,8 +48,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0F; testSession.evaluate(expected, loss); @@ -62,8 +63,9 @@ public void testAllCorrectUnweighted() { yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new 
CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); testSession.setEpsilon(1e-3F); testSession.evaluate(0.0F, loss); } @@ -81,7 +83,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] trueArray = { 1L, 0L, 0L, 0L, 1L, 0L, @@ -97,7 +100,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -109,7 +112,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; float[] predArray = { .9F, .05F, .05F, @@ -118,7 +122,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814F; testSession.evaluate(expected, loss); @@ -130,8 +134,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.0573755F; testSession.evaluate(expected, loss); } @@ -158,8 
+163,9 @@ public void testScalarWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3F); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267F; testSession.evaluate(expected, loss); @@ -171,8 +177,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386F; testSession.evaluate(expected, loss); } @@ -183,7 +190,8 @@ public void testSsampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; int[] trueArray = { 1, 0, 0, @@ -199,7 +207,7 @@ public void testSsampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampeWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696F; testSession.evaluate(expected, loss); @@ -211,8 +219,9 @@ public void testSsampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + 
instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829F; testSession.evaluate(expected, loss); } @@ -234,9 +243,9 @@ public void testNoReduction() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = - new CategoricalCrossentropy(tf, true, 0.0F, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, 0.0F, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822F, 0.000459F, 0.169846F}; testSession.evaluate(expected, loss); } @@ -254,8 +263,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = 400.0F * labelSmoothing / 3.0F; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java index b0d0442b3c7..d00f5374d61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -31,12 +31,13 @@ public void testReductionNone() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); + 
CategoricalHinge instance = new CategoricalHinge(Reduction.NONE); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {0.0f, 65.0f}; testSession.evaluate(expected, loss); } @@ -48,12 +49,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 32.5f; testSession.evaluate(expected, loss); } @@ -65,17 +67,18 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83.95f; testSession.evaluate(expected, loss); - Operand loss2 = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss2 = instance.call(tf, yTrue, yPred, sampleWeight); testSession.evaluate(loss, loss2); } } @@ -85,7 +88,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {1.2f, 3.4f}; @@ -93,7 +97,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 124.1f; testSession.evaluate(expected, loss); } @@ -104,13 +108,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -121,7 +126,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new 
CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {3, 6, 5, 0, 4, 2}; @@ -130,7 +136,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.0f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java index 8350d1403ed..2f21929a969 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -33,11 +33,12 @@ public void testReductionNone() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf, Reduction.NONE); + CosineSimilarity instance = new CosineSimilarity(Reduction.NONE); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {-0.720488f, 0.3460499f}; testSession.evaluate(expected, loss); } @@ -52,11 +53,12 @@ public void testUnweighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + 
Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } @@ -71,12 +73,13 @@ public void testScalarWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, 2.3f)); testSession.evaluate(expected, loss); } @@ -90,14 +93,15 @@ public void testSampleWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] weightsArray = {1.2f, 3.4f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, weightsArray)); testSession.evaluate(expected, loss); } @@ -108,14 +112,15 @@ public void testZeroWeighted() { for (TestSession.Mode 
tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -128,14 +133,15 @@ public void testTimestepWeighted() { Ops tf = testSession.getTF(); float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3, 1); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); float[] weightsArray = {3, 6, 5, 0, 4, 2}; Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -2.0f; testSession.evaluate(expected, loss); } @@ -149,11 +155,12 @@ public void testAxis() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf, 1); + CosineSimilarity instance = new CosineSimilarity(1); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), 
tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 4770511207e..d5fe846c82e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -33,12 +33,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.50625f; testSession.evaluate(expected, loss); } @@ -56,14 +57,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {2f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -75,13 +77,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.164375f; testSession.evaluate(expected, loss); @@ -94,7 +97,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -102,7 +106,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.06125f; testSession.evaluate(expected, loss); } @@ -113,13 +117,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 
4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -130,7 +135,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf, Reduction.AUTO); + Hinge instance = new Hinge(Reduction.AUTO); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -140,7 +146,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0125f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java index d1751f223a1..86a71e5ecbb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -32,8 +32,9 @@ public void testAllCorrect() { float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yTrue); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } 
@@ -50,8 +51,9 @@ public void testUnweighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yPred); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.10416666666666669f; testSession.evaluate(expected, loss); } @@ -67,9 +69,10 @@ public void testScalarWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.23958333333333337f; testSession.evaluate(expected, loss); @@ -87,10 +90,11 @@ public void testSampleWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.22766666666666668f; testSession.evaluate(expected, loss); } @@ -105,9 +109,10 @@ public void testZeroWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance 
= new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -125,10 +130,11 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .4025f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java index d57b61b18dd..1d7ee87b920 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.5960738398643668f; testSession.evaluate(expected, loss); } @@ -47,13 
+48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.3709698316880434f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0075711736936492f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, 
.4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); + KLDivergence instance = new KLDivergence(Reduction.AUTO); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.2495994912084345f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java index c4347b3fccb..ce6782cee3b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = 
new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 4.829245330860459f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.107264260979056f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
12.001114667519486f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf, Reduction.AUTO); + LogCosh instance = new LogCosh(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.653484271934046f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java index 3498c6d53aa..cbcb2c37391 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 5.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), 
tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 12.65f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 81.4f / 6f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ 
-116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.AUTO); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.NONE); + 
MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {10.733333f, 14.566667f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.SUM); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {25.29999f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java index 7816a8a288a..b521f2f5644 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -30,10 +30,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -45,12 +46,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 211.85184f; testSession.evaluate(expected, loss); } @@ -62,13 +64,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 487.25922f; testSession.evaluate(expected, loss); } @@ -79,7 +82,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -87,7 +91,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 422.8889f; testSession.evaluate(expected, loss); } @@ -98,13 +102,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -115,7 +120,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : 
tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.AUTO); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -125,7 +131,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 694.4445f; testSession.evaluate(expected, loss); } @@ -136,13 +142,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.NONE); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {621.8518f, 352.66666f}; testSession.evaluate(expected, loss); } @@ -153,13 +160,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new 
MeanAbsolutePercentageError(tf, Reduction.SUM); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 974.51843f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java index 1a971f0492b..e9fd0d7e349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 
9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 49.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 113.85f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
127.96667f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.AUTO); + MeanSquaredError instance = new MeanSquaredError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 97.833336f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new 
MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 173.25f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.NONE); + MeanSquaredError instance = new MeanSquaredError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {84.333336f, 143.36665f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.SUM); + MeanSquaredError instance = new MeanSquaredError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 
3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {227.69998f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java index 558f9c84659..0c6d411c53f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = 
instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 1.4370421f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.3051968f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.7856376f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.647374f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + 
float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {2.3006392f, 4.3097544f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand 
yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {6.6103935f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java index 55c59ca5ac6..c354c83bfe2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -3.306581945521002f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -7.605138474698304f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -6.147338926788071f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf, 
Reduction.AUTO); + Poisson instance = new Poisson(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -12.263126013890561f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index a6a0ff35c78..113b89b82ff 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -44,8 +44,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -57,8 +58,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, 
logits); testSession.evaluate(0.0f, loss); } } @@ -75,7 +77,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { 1.9f, .05f, .05f, @@ -86,7 +89,7 @@ public void testInvalidPredictionsRange() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -98,7 +101,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { .9f, .05f, .05f, @@ -107,7 +111,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814f; testSession.evaluate(expected, loss); @@ -119,8 +123,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.05737559f; testSession.evaluate(expected, loss); } @@ -143,8 +148,9 @@ public void testScalarWeighted() { Operand yPred = 
tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3f); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267f; testSession.evaluate(expected, loss); @@ -156,8 +162,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386f; testSession.evaluate(expected, loss); } @@ -168,7 +175,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + float[] sampleWeightArray = {1.2f, 3.4f, 5.6f}; int[] trueArray = {0, 1, 2}; float[] predArray = { @@ -180,7 +188,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696f; testSession.evaluate(expected, loss); @@ -192,8 +200,9 @@ public void testSampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = 
instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829f; testSession.evaluate(expected, loss); } @@ -216,8 +225,9 @@ public void testNoReduction() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy(tf, true, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + new SparseCategoricalCrossentropy(true, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822f, 0.000459f, 0.169846f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 57a012bbe9d..979e778e4c3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -32,12 +32,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.364062f; testSession.evaluate(expected, loss); } @@ -55,14 +56,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - SquaredHinge instance 
= new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 2, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -74,13 +76,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.8373437f; testSession.evaluate(expected, loss); } @@ -91,7 +94,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -99,7 +103,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.7043125f; testSession.evaluate(expected, loss); } @@ -110,13 +114,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -127,7 +132,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); + SquaredHinge instance = new SquaredHinge(Reduction.AUTO); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = @@ -137,7 +143,7 @@ public void testTimestepWeighted() { float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.54250000f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d6786b71972..d957cfb2508 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,13 +1,17 @@ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -26,10 +30,8 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -136,14 +138,14 @@ public void testDeterminism() { Ops tf = Ops.create(g); Glorot initializer = - new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); + new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); // Inputs Placeholder input = tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20))); // Fully connected layer Variable fcWeights = - tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(20L, 200L), TFloat32.class)); fcWeightName = fcWeights.op().name(); Variable 
fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f))); fcBiasName = fcBiases.op().name(); @@ -151,13 +153,13 @@ public void testDeterminism() { // Output layer Variable outputWeights = - tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(200L, 2L), TFloat32.class)); outputWeightName = outputWeights.op().name(); Variable outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f))); outputBiasName = outputBiases.op().name(); Add output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases); - // Loss + // AbstractLoss Placeholder placeholder = tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); Mean loss = @@ -205,12 +207,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - TFloat32 lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + TFloat32 lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); initialLoss[i] = lossVal.getFloat(); lossVal.close(); @@ -222,12 +227,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); postTrainingLoss[i] = lossVal.getFloat(); lossVal.close(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 181ae367f07..a4b98c002cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java 
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -17,25 +17,25 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.2f, 0.3f); + L1L2 instance = new L1L2(0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2(tf, 0, 0); + instance = new L1L2(0, 0); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0.5f, 0); + instance = new L1L2(0.5f, 0); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0, 0.5f); + instance = new L1L2(0, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); - instance = new L1L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + instance = new L1L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); } } @@ -44,8 +44,8 @@ public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf); - Operand result = instance.call(tf.constant(555f)); + L1L2 instance = new L1L2(); + Operand result = instance.call(tf, tf.constant(555f)); session.evaluate(3085.8f, result); } } @@ -55,10 +55,10 @@ public void testCallL1L2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0); + L1L2 instance = new L1L2(0, 0); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); - Operand 
result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -68,10 +68,10 @@ public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -83,10 +83,10 @@ public void testCallL1L2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -98,10 +98,10 @@ public void testCallL2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0); + L1L2 instance = new L1L2(0.01f, 0); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1(w, 0.01f); session.evaluate(expected, result); } @@ -112,10 +112,10 @@ public void testCallL1_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0.02f); 
+ L1L2 instance = new L1L2(0, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 0e42a257816..f7d540fb8e1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.2f); + L1 instance = new L1(0.2f); assertEquals(0.2f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf, 0f); + instance = new L1(0f); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + instance = new L1(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(0.f, instance.getL2()); } } @@ -36,10 +36,10 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.0f); + L1 instance = new L1(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0f, result); } } @@ -49,11 +49,11 @@ public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession 
session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf); + L1 instance = new L1(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL1(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.evaluate(expected, result); } } @@ -63,10 +63,10 @@ public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.02f); + L1 instance = new L1(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1(w, 0.02f); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index aba036ee306..4579ccaf551 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.2f); + L2 instance = new L2(0.2f); assertEquals(0.2f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2(tf, 0f); + instance = new L2(0f); assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - L2 instance64 = new L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + L2 
instance64 = new L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); assertEquals(0.f, instance64.getL1()); } } @@ -36,10 +36,10 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.0f); + L2 instance = new L2(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -49,11 +49,11 @@ public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf); + L2 instance = new L2(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL2(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.setEpsilon(.01f); session.evaluate(expected, result); } @@ -64,10 +64,10 @@ public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.02f); + L2 instance = new L2(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index fe2624cec3d..6918f631e6a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -14,13 +14,13 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 regularizer = new L1L2(tf, 0.01f, 0f); + L1L2 regularizer = new L1L2(0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand regularizerResult = regularizer.call(weights); - RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + Operand regularizerResult = regularizer.call(tf, weights); + RegularizerLoss lossInstance = new RegularizerLoss(regularizer); - Operand loss = lossInstance.call(null, null, weights); + Operand loss = lossInstance.call(tf, null, null, weights); session.evaluate(regularizerResult, loss); } } From 4eb231bb24f93d230ef51b4f613cf2444d4dac0c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 1 Jun 2021 11:18:31 -0400 Subject: [PATCH 3/5] Move Ops from CTOR to call method --- .../activations/AbstractActivation.java | 46 ++++++ .../framework/activations/Activation.java | 45 +----- .../tensorflow/framework/activations/ELU.java | 34 ++--- .../framework/activations/Exponential.java | 23 +-- .../framework/activations/HardSigmoid.java | 31 ++-- .../framework/activations/Linear.java | 18 +-- .../framework/activations/ReLU.java | 30 ++-- .../framework/activations/SELU.java | 21 +-- .../framework/activations/Sigmoid.java | 21 +-- .../framework/activations/Softmax.java | 22 +-- .../framework/activations/Softplus.java | 21 +-- .../framework/activations/Softsign.java | 21 +-- .../framework/activations/Swish.java | 11 +- .../framework/activations/Tanh.java | 14 +- 
.../constraints/AbstractConstraint.java | 89 ++++++++++++ .../framework/constraints/Constraint.java | 88 +----------- .../framework/constraints/MaxNorm.java | 30 ++-- .../framework/constraints/MinMaxNorm.java | 30 ++-- .../framework/constraints/NonNeg.java | 15 +- .../framework/constraints/UnitNorm.java | 31 ++-- .../initializers/BaseInitializer.java | 21 ++- .../framework/initializers/Constant.java | 31 ++-- .../framework/initializers/Glorot.java | 12 +- .../tensorflow/framework/initializers/He.java | 16 +-- .../framework/initializers/Identity.java | 30 ++-- .../framework/initializers/Initializer.java | 7 +- .../framework/initializers/LeCun.java | 15 +- .../framework/initializers/Ones.java | 20 +-- .../framework/initializers/Orthogonal.java | 21 +-- .../framework/initializers/RandomNormal.java | 26 ++-- .../framework/initializers/RandomUniform.java | 31 ++-- .../initializers/TruncatedNormal.java | 23 +-- .../initializers/VarianceScaling.java | 32 ++--- .../framework/initializers/Zeros.java | 17 +-- .../framework/losses/BinaryCrossentropy.java | 79 +++++----- .../losses/CategoricalCrossentropy.java | 135 ++++++++---------- .../framework/losses/CategoricalHinge.java | 40 +++--- .../framework/losses/CosineSimilarity.java | 115 +++++++-------- .../tensorflow/framework/losses/Hinge.java | 48 +++---- .../tensorflow/framework/losses/Huber.java | 61 ++++---- .../framework/losses/KLDivergence.java | 50 ++++--- .../tensorflow/framework/losses/LogCosh.java | 54 ++++--- .../org/tensorflow/framework/losses/Loss.java | 78 +--------- .../framework/losses/MeanAbsoluteError.java | 44 +++--- .../losses/MeanAbsolutePercentageError.java | 45 +++--- .../framework/losses/MeanSquaredError.java | 44 +++--- .../losses/MeanSquaredLogarithmicError.java | 44 +++--- .../tensorflow/framework/losses/Poisson.java | 54 ++++--- .../framework/losses/Reduction.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 73 +++++----- .../framework/losses/SquaredHinge.java | 53 ++++--- 
.../framework/losses/impl/AbstractLoss.java | 89 ++++++++++++ .../org/tensorflow/framework/metrics/AUC.java | 95 ++++++------ .../framework/metrics/Accuracy.java | 8 +- .../framework/metrics/BinaryAccuracy.java | 8 +- .../metrics/CategoricalAccuracy.java | 19 ++- .../metrics/CategoricalCrossentropy.java | 20 ++- .../framework/metrics/FalseNegatives.java | 42 +++--- .../framework/metrics/FalsePositives.java | 42 +++--- .../tensorflow/framework/metrics/MeanIoU.java | 14 +- .../framework/metrics/MeanRelativeError.java | 11 +- .../framework/metrics/MeanTensor.java | 4 +- .../framework/metrics/Precision.java | 71 +++++---- .../framework/metrics/PrecisionAtRecall.java | 7 +- .../tensorflow/framework/metrics/Recall.java | 26 ++-- .../framework/metrics/RecallAtPrecision.java | 4 +- .../metrics/RootMeanSquaredError.java | 3 +- .../metrics/SensitivityAtSpecificity.java | 20 +-- .../metrics/SparseCategoricalAccuracy.java | 6 +- .../metrics/SpecificityAtSensitivity.java | 20 +-- .../org/tensorflow/framework/metrics/Sum.java | 8 +- .../metrics/TopKCategoricalAccuracy.java | 4 +- .../framework/metrics/TrueNegatives.java | 42 +++--- .../framework/metrics/TruePositives.java | 42 +++--- .../impl/ConfusionMatrixConditionCount.java | 26 ++-- .../framework/metrics/impl/LossMetric.java | 2 +- .../metrics/impl/MeanMetricWrapper.java | 8 +- .../framework/metrics/impl/MetricsHelper.java | 116 +++++++-------- .../impl/SensitivitySpecificityBase.java | 6 +- .../framework/metrics/impl/SetsOps.java | 24 ++-- .../framework/metrics/impl/SymbolicShape.java | 45 +++++- .../metrics/impl/WeightsBroadcastOps.java | 34 ++--- .../regularizers/AbstractRegularizer.java | 63 ++++++++ .../tensorflow/framework/regularizers/L1.java | 33 +++-- .../framework/regularizers/L1L2.java | 38 ++--- .../tensorflow/framework/regularizers/L2.java | 33 +++-- .../framework/regularizers/Regularizer.java | 67 +-------- .../regularizers/RegularizerLoss.java | 31 ++-- .../framework/activations/ELUTest.java | 33 +---- 
.../activations/ExponentialTest.java | 28 +--- .../activations/HardSigmoidTest.java | 28 +--- .../framework/activations/LinearTest.java | 28 +--- .../framework/activations/ReLUTest.java | 58 ++++---- .../framework/activations/SELUTest.java | 28 +--- .../framework/activations/SigmoidTest.java | 27 +--- .../framework/activations/SoftmaxTest.java | 47 ++---- .../framework/activations/SoftplusTest.java | 24 +--- .../framework/activations/SoftsignTest.java | 24 +--- .../framework/activations/SwishTest.java | 28 +--- .../framework/activations/TanhTest.java | 24 +--- .../framework/constraints/MaxNormTest.java | 8 +- .../framework/constraints/MinMaxNormTest.java | 4 +- .../framework/constraints/NonNegTest.java | 8 +- .../framework/constraints/UnitNormTest.java | 8 +- .../framework/initializers/ConstantTest.java | 66 ++++----- .../framework/initializers/GlorotTest.java | 57 ++++---- .../framework/initializers/HeTest.java | 57 ++++---- .../framework/initializers/IdentityTest.java | 34 ++--- .../framework/initializers/LeCunTest.java | 50 +++---- .../framework/initializers/OnesTest.java | 72 +++++----- .../initializers/OrthogonalTest.java | 34 ++--- .../initializers/RandomNormalTest.java | 33 ++--- .../initializers/RandomUniformTest.java | 38 ++--- .../initializers/TruncatedNormalTest.java | 33 ++--- .../initializers/VarianceScalingTest.java | 73 +++------- .../framework/initializers/ZerosTest.java | 72 +++++----- .../losses/BinaryCrossentropyTest.java | 54 ++++--- .../losses/CategoricalCrossentropyTest.java | 56 +++++--- .../losses/CategoricalHingeTest.java | 32 +++-- .../losses/CosineSimilarityTest.java | 35 +++-- .../framework/losses/HingeTest.java | 30 ++-- .../framework/losses/HuberTest.java | 30 ++-- .../framework/losses/KLDivergenceTest.java | 25 ++-- .../framework/losses/LogCoshTest.java | 25 ++-- .../losses/MeanAbsoluteErrorTest.java | 45 +++--- .../MeanAbsolutePercentageErrorTest.java | 40 +++--- .../losses/MeanSquaredErrorTest.java | 45 +++--- 
.../MeanSquaredLogarithmicErrorTest.java | 45 +++--- .../framework/losses/PoissonTest.java | 25 ++-- .../SparseCategoricalCrossentropyTest.java | 50 ++++--- .../framework/losses/SquaredHingeTest.java | 30 ++-- .../optimizers/GradientDescentTest.java | 48 ++++--- .../framework/regularizers/L1L2Test.java | 38 ++--- .../framework/regularizers/L1Test.java | 22 +-- .../framework/regularizers/L2Test.java | 22 +-- .../regularizers/RegularizerLossTest.java | 8 +- 136 files changed, 2278 insertions(+), 2544 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java new file mode 100644 index 00000000000..335b8697273 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Abstract base class for Activations */ +public abstract class AbstractActivation implements Activation { + + /** The TensorFlow Ops */ + protected Ops tf; + + /** Creates the abstract class for an AbstractActivation */ + protected AbstractActivation() {} + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + protected Ops getTF() { + return this.tf; + } + + /** + * Sets the TensorFlow Ops + * + * @param tf the TensorFlow Ops + */ + protected void setTF(Ops tf) { + this.tf = tf; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java index e1482a51a8a..f73c6678ab3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,50 +19,19 @@ import org.tensorflow.types.family.TNumber; /** - * Abstract base class for Activations + * Interface for Activations * - *

      Note: The {@link #tf} attribute must be set prior to invoking the call method. See - * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. - * - * @param the data type of the activation + * @param the data type of the input and the result */ -public abstract class Activation { - - /** The TensorFlow Ops */ - protected Ops tf; - - /** - * Creates the abstract class for an Activation - * - * @param tf the TensorFlow Ops - */ - protected Activation(Ops tf) { - this.tf = tf; - } - - /** - * Sets the TensorFlow Ops - * - * @param tf the TensorFlow Ops - */ - protected void setTF(Ops tf) { - this.tf = tf; - } - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - protected Ops getTF() { - return this.tf; - } +@FunctionalInterface +public interface Activation { /** * Gets the calculation operation for the activation. * + * @param tf the TensorFlow Ops * @param input the input tensor * @return The operand for the activation */ - public abstract Operand call(Operand input); + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 2f2f16f2752..919a947a127 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Exponential linear unit. 
* @@ -44,53 +46,41 @@ * Operand<TFloat32> result = elu.call(input); * * - * @param the data type of the activation * @see Clevert et al, 2016, Fast and Accurate Deep * Network Learning by Exponential Linear Units (ELUs) */ -public class ELU extends Activation { +public class ELU extends AbstractActivation { private static final double ALPHA_DEFAULT = 1.0; /** A scalar, slope of negative section. */ private final double alpha; - /** - * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. - * - * @param tf the TensorFlow Ops - */ - public ELU(Ops tf) { - this(tf, ALPHA_DEFAULT); + /** Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. */ + public ELU() { + this(ALPHA_DEFAULT); } /** * Creates a new ELU * - * @param tf the TensorFlow Ops * @param alpha A scalar, slope of negative section. It controls the value to which an ELU * saturates for negative net inputs. */ - public ELU(Ops tf, double alpha) { - super(tf); + public ELU(double alpha) { + super(); this.alpha = alpha; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { Operand result = tf.nn.elu(input); if (alpha == 1.0) return result; else { Class inputType = input.type(); - Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); - Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); + Operand y = tf.math.mul(result, cast(tf, tf.constant(alpha), inputType)); + Operand cond = tf.math.greater(result, cast(tf, tf.constant(0), inputType)); return tf.select(cond, result, y); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java index d5fdff36c61..8398ada6362 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -30,28 +30,17 @@ * Operand<TFloat32> result = exp.call(input); * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f] * - * - * @param the data type of the activation */ -public class Exponential extends Activation { +public class Exponential extends AbstractActivation { - /** - * Creates an Exponential activation. - * - * @param tf the TensorFlow Ops - */ - public Exponential(Ops tf) { - super(tf); + /** Creates an Exponential activation. */ + public Exponential() { + super(); } - /** - * Calculates the Exponential activation. - * - * @param input the input tensor - * @return an Operand for the exponential activation: exp(x). 
- */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.exp(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index 0b7cf573b8e..fac4d14eca5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -18,6 +18,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Hard sigmoid activation. * @@ -40,34 +42,23 @@ * Operand<TFloat32> result = hardSigmoid.call(input); * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f] * - * - * @param the data type of the result */ -public class HardSigmoid extends Activation { +public class HardSigmoid extends AbstractActivation { - /** - * Creates Hard sigmoid activation. - * - * @param tf the TensorFlow Ops - */ - public HardSigmoid(Ops tf) { - super(tf); + /** Creates Hard sigmoid activation. */ + public HardSigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); - Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); + Operand point2 = cast(tf, tf.constant(0.2), inputType); + Operand point5 = cast(tf, tf.constant(0.5), inputType); Operand x = tf.math.add(tf.math.mul(input, point2), point5); return tf.clipByValue( - x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType)); + x, cast(tf, tf.constant(0), inputType), cast(tf, tf.constant(1), inputType)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java index d907397995d..d1a5eede616 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -19,9 +19,9 @@ import org.tensorflow.types.family.TNumber; /** - * Linear activation function (pass-through). + * Linear activation function (pass-through). * - *

      The linear activation returns its input. It is also known as the Identity activation function.

      + *

      The linear activation returns its input. It is also known as the Identity activation function. * *

      For example: * @@ -33,20 +33,16 @@ * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f] * */ -public class Linear extends Activation { +public class Linear extends AbstractActivation { - /** - * Creates a linear activation. - * - * @param tf the TensorFlow Ops - */ - public Linear(Ops tf) { - super(tf); + /** Creates a linear activation. */ + public Linear() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return input; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index aef6ebf2992..c966e5d9ddd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -20,6 +20,8 @@ import org.tensorflow.op.nn.LeakyRelu; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Rectified Linear Unit(ReLU) activation. * @@ -58,7 +60,7 @@ * * @param the data type of the result */ -public class ReLU extends Activation { +public class ReLU extends AbstractActivation { public static final float ALPHA_DEFAULT = 0.0f; public static final float MAX_VALUE_DEFAULT = Float.NaN; @@ -71,24 +73,21 @@ public class ReLU extends Activation { /** * Creates a new ReLU with alpha={@link #ALPHA_DEFAULT}, maxValue={@link #MAX_VALUE_DEFAULT}, * threshold={@link #THRESHOLD_DEFAULT}, - * - * @param tf the TensorFlow Ops */ - public ReLU(Ops tf) { - this(tf, ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); + public ReLU() { + this(ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); } /** * Creates a new ReLU * - * @param tf the TensorFlow Ops * @param alpha governs the slope for values lower than the threshold. * @param maxValue sets the saturation threshold (the largest value the function will return). 
* @param threshold the threshold value of the activation function below which values will be * damped or set to zero. */ - public ReLU(Ops tf, float alpha, float maxValue, float threshold) { - super(tf); + public ReLU(float alpha, float maxValue, float threshold) { + super(); this.alpha = alpha; this.maxValue = maxValue; this.threshold = threshold; @@ -96,7 +95,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); @@ -108,7 +107,7 @@ public Operand call(Operand input) { if (threshold != 0) { negativePart = tf.nn.relu( - tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), inputType))); + tf.math.add(tf.math.neg(input), cast(tf, tf.constant(threshold), inputType))); } else { negativePart = tf.nn.relu(tf.math.neg(input)); } @@ -117,8 +116,8 @@ public Operand call(Operand input) { Operand lInput; if (threshold != 0) { // computes input for input > threshold else 0 - Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType)); - lInput = tf.math.mul(input, tf.dtypes.cast(greater, inputType)); + Greater greater = tf.math.greater(input, cast(tf, tf.constant(threshold), inputType)); + lInput = tf.math.mul(input, cast(tf, greater, inputType)); } else if (maxValue == 6) { // if no threshold, then can use nn.relu6 native TF op for performance lInput = tf.nn.relu6(input); @@ -127,15 +126,14 @@ public Operand call(Operand input) { lInput = tf.nn.relu(input); } if (clipMax) { - Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); - Operand zero = tf.dtypes.cast(tf.constant(0), inputType); + Operand lmaxValue = cast(tf, tf.constant(maxValue), inputType); + Operand zero = cast(tf, tf.constant(0), inputType); lInput = tf.clipByValue(lInput, zero, lmaxValue); } if (alpha != 0.) 
{ lInput = - tf.math.sub( - lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), inputType), negativePart)); + tf.math.sub(lInput, tf.math.mul(cast(tf, tf.constant(alpha), inputType), negativePart)); } return lInput; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java index f24731049fb..a28052486e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -45,25 +45,16 @@ * @param the data type of the activation * @see Klambauer et al., 2017 */ -public class SELU extends Activation { +public class SELU extends AbstractActivation { - /** - * Creates a Scaled Exponential Linear Unit (SELU) activation. - * - * @param tf the TensorFlow Ops - */ - public SELU(Ops tf) { - super(tf); + /** Creates a Scaled Exponential Linear Unit (SELU) activation. */ + public SELU() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.selu(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java index 5d507b38483..02b2daae4d6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -41,25 +41,16 @@ * * @param the data type of the activation */ -public class Sigmoid extends Activation { +public class Sigmoid extends AbstractActivation { - /** - * Creates a Sigmoid activation. 
- * - * @param tf the TensorFlow Ops - */ - public Sigmoid(Ops tf) { - super(tf); + /** Creates a Sigmoid activation. */ + public Sigmoid() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.sigmoid(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index 154e1ecc84a..3aa67a179ad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -38,7 +38,7 @@ * * @param the data type of the activation */ -public class Softmax extends Activation { +public class Softmax extends AbstractActivation { private static final int AXIS_DEFAULT = -1; @@ -47,32 +47,24 @@ public class Softmax extends Activation { /** * Creates a softmax activation where the default axis is {@link #AXIS_DEFAULT} which indicates * the last dimension. - * - * @param tf the TensorFlow Ops */ - public Softmax(Ops tf) { - this(tf, AXIS_DEFAULT); + public Softmax() { + this(AXIS_DEFAULT); } /** * Creates a Softmax activation * - * @param tf the TensorFlow Ops * @param axis The dimension softmax would be performed on. */ - public Softmax(Ops tf, int axis) { - super(tf); + public Softmax(int axis) { + super(); this.axis = axis; } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java index 65a183ea047..8533de7852c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -32,25 +32,16 @@ * // 1.3132616e+00f, 2.0000000e+01f] * */ -public class Softplus extends Activation { +public class Softplus extends AbstractActivation { - /** - * Creates a Softplus activation function. - * - * @param tf the TensorFlow Ops - */ - public Softplus(Ops tf) { - super(tf); + /** Creates a Softplus activation function. */ + public Softplus() { + super(); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.softplus(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java index 1f691e71862..249fa6077cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -33,25 +33,16 @@ * * @param the data type of the activation */ -public class Softsign extends Activation { +public class Softsign extends AbstractActivation { - /** - * Creates a Softsign activation. - * - * @param tf the TensorFlow Ops - */ - public Softsign(Ops tf) { - super(tf); + /** Creates a Softsign activation. */ + public Softsign() { + super(); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.nn.softsign(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java index d9f73a422d5..5007dd34555 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -40,7 +40,7 @@ * @param the data type of the activation * @see Ramachandran et al., 2017 */ -public class Swish extends Activation { +public class Swish extends AbstractActivation { /** * Creates a Swish activation, swish(x) = x * sigmoid(x). 
@@ -48,17 +48,14 @@ public class Swish extends Activation { *

      Swish activation function which returns x*sigmoid(x). It is a smooth, * non-monotonic function that consistently matches or outperforms ReLU on deep networks, it is * unbounded above and bounded below. - * - * @param tf the TensorFlow Ops */ - public Swish(Ops tf) { - super(tf); + public Swish() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - + public Operand call(Ops tf, Operand input) { // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. return tf.math.mul(input, tf.math.sigmoid(input)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java index 4fe02eed048..37d4d811a0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -33,20 +33,16 @@ * * @param the data type of the activation */ -public class Tanh extends Activation { +public class Tanh extends AbstractActivation { - /** - * Creates a Hyperbolic tangent activation. - * - * @param tf the TensorFlow Ops - */ - public Tanh(Ops tf) { - super(tf); + /** Creates a Hyperbolic tangent activation. */ + public Tanh() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Ops tf, Operand input) { return tf.math.tanh(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java new file mode 100644 index 00000000000..266d01620bd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Base class for Constraints. AbstractConstraint subclasses impose constraints on weight values */ +public abstract class AbstractConstraint implements Constraint { + + public static final float EPSILON = 1e-7f; + + /** Creates an AbstractConstraint */ + public AbstractConstraint() {} + + /** + * Gets the element-wise square root. + * + * @param tf the TensorFlow Ops + * @param x the input Operand. + * @return the element-wise square root. + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand sqrt(Ops tf, Operand x) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); + } + + /** + * Gets the element-wise value clipping. 
+ * + * @param tf the TensorFlow Ops + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand clip( + Ops tf, Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } + + /** + * Calculates the norm of the weights along the axes + * + * @param tf the TensorFlow Ops + * @param weights the weights used to calculate the norms + * @param axes the axes along which to calculate weight norms. + * @param the data type for the weights and the result + * @return the norms + * @throws IllegalArgumentException if weights is null + */ + protected Operand norm(Ops tf, Operand weights, int[] axes) { + if (weights == null) throw new IllegalArgumentException("weights must not be null"); + return sqrt( + tf, + tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index 306361959bf..97640b19cf8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,96 +16,16 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** Base class for Constraints. Constraint subclasses impose constraints on weight values */ -public abstract class Constraint { - - public static final float EPSILON = 1e-7f; - - private final Ops tf; - - /** - * Creates a Constraint - * - * @param tf the TensorFlow Ops - */ - public Constraint(Ops tf) { - this.tf = tf; - } - +public interface Constraint { /** * Applies the constraint against the provided weights * + * @param tf the TensorFlow Ops * @param weights the weights * @return the constrained weights * @param the data type for weights and results. */ - public abstract Operand call(Operand weights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the element-wise square root. - * - * @param x the input Operand. - * @return the element-wise square root. - * @param The data type for the operand and result. - * @throws IllegalArgumentException if x is null - */ - protected Operand sqrt(Operand x) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Class type = x.type(); - Operand zero = cast(tf, tf.constant(0), type); - Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); - return tf.math.sqrt(tf.clipByValue(x, zero, inf)); - } - - /** - * Gets the element-wise value clipping. - * - * @param x the Operand to clip - * @param minValue the minimum value - * @param maxValue the maximum value - * @return the operand with clipped values - * @param The data type for the operand and result. 
- * @throws IllegalArgumentException if x is null - */ - protected Operand clip(Operand x, double minValue, double maxValue) { - if (x == null) throw new IllegalArgumentException("Operand x must not be null"); - Ops tf = getTF(); - Class type = x.type(); - - double min = Math.min(minValue, maxValue); - double max = Math.max(minValue, maxValue); - - Operand minValueConstant = cast(tf, tf.constant(min), type); - Operand maxValueConstant = cast(tf, tf.constant(max), type); - return tf.clipByValue(x, minValueConstant, maxValueConstant); - } - - /** - * Calculates the norm of the weights along the axes - * - * @param weights the weights used to calculate the norms - * @param axes the axes along which to calculate weight norms. - * @param the data type for the weights and the result - * @return the norms - * @throws IllegalArgumentException if weights is null - */ - protected Operand norm(Operand weights, int[] axes) { - if (weights == null) throw new IllegalArgumentException("weights must not be null"); - return sqrt( - tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); - } + Operand call(Ops tf, Operand weights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index 1dae117b113..b9f082f54de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -24,7 +24,7 @@ * Constrains the weights incident to each hidden unit to have a norm less than or equal to a * desired value. 
*/ -public class MaxNorm extends Constraint { +public class MaxNorm extends AbstractConstraint { public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; @@ -36,54 +36,48 @@ public class MaxNorm extends Constraint { /** * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link * #AXIS_DEFAULT} for the axis. - * - * @param tf the TensorFlow Ops */ - public MaxNorm(Ops tf) { - this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + public MaxNorm() { + this(MAX_VALUE_DEFAULT, AXIS_DEFAULT); } /** * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. */ - public MaxNorm(Ops tf, double maxValue) { - this(tf, maxValue, AXIS_DEFAULT); + public MaxNorm(double maxValue) { + this(maxValue, AXIS_DEFAULT); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axis axis along which to calculate weight norms. */ - public MaxNorm(Ops tf, double maxValue, int axis) { - this(tf, maxValue, new int[] {axis}); + public MaxNorm(double maxValue, int axis) { + this(maxValue, new int[] {axis}); } /** * Create a MaxNorm constraint * - * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. * @param axes axes along which to calculate weight norms. 
*/ - public MaxNorm(Ops tf, double maxValue, int[] axes) { - super(tf); + public MaxNorm(double maxValue, int[] axes) { + super(); this.maxValue = maxValue; this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Operand norms = norm(weights, getAxes()); - Operand desired = clip(norms, 0f, this.getMaxValue()); + Operand norms = norm(tf, weights, getAxes()); + Operand desired = clip(tf, norms, 0f, this.getMaxValue()); return tf.math.mul( weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 04b21572e55..97e86d7693f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -21,7 +21,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have the norm between a lower bound and an upper bound. 
*/ -public class MinMaxNorm extends Constraint { +public class MinMaxNorm extends AbstractConstraint { public static final double MIN_VALUE_DEFAULT = 0.0; public static final double MAX_VALUE_DEFAULT = 1.0; public static final double RATE_DEFAULT = 1.0; @@ -47,48 +47,43 @@ public class MinMaxNorm extends Constraint { * Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis - * - * @param tf the TensorFlow Ops */ - public MinMaxNorm(Ops tf) { - this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm() { + this(MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue) { - this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + public MinMaxNorm(double minValue, double maxValue) { + this(minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axis integer, axis along which to calculate weight norms. 
*/ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { - this(tf, minValue, maxValue, rate, new int[] {axis}); + public MinMaxNorm(double minValue, double maxValue, double rate, int axis) { + this(minValue, maxValue, rate, new int[] {axis}); } /** * Create a MinMaxNorm constraint * - * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. * @param rate the rate for enforcing the constraint. * @param axes integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { - super(tf); + public MinMaxNorm(double minValue, double maxValue, double rate, int[] axes) { + super(); this.minValue = minValue; this.maxValue = maxValue; this.rate = rate; @@ -97,15 +92,14 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] a /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); - Operand norms = norm(weights, getAxes()); + Operand norms = norm(tf, weights, getAxes()); Operand desired = tf.math.add( tf.math.mul( tf.dtypes.cast(tf.constant(this.getRate()), type), - clip(norms, this.getMinValue(), this.getMaxValue())), + clip(tf, norms, this.getMinValue(), this.getMaxValue())), tf.math.mul( tf.math.sub( tf.dtypes.cast(tf.constant(1), type), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java index 0194b2fadb6..6a5677983fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -19,21 +19,16 @@ import org.tensorflow.types.family.TNumber; /** Constrains the weights to be 
non-negative. */ -public class NonNeg extends Constraint { +public class NonNeg extends AbstractConstraint { - /** - * Create a NonNeg constraint - * - * @param tf the TensorFlow Ops - */ - public NonNeg(Ops tf) { - super(tf); + /** Create a NonNeg constraint */ + public NonNeg() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); return tf.math.mul( weights, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java index 70bb1a59785..fdd71945229 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -21,50 +21,43 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** Constrains the weights to have unit norm. */ -public class UnitNorm extends Constraint { +public class UnitNorm extends AbstractConstraint { public static final int AXIS_DEFAULT = 0; /** integer, axis along which to calculate weight norms. */ private final int[] axes; - /** - * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - public UnitNorm(Ops tf) { - this(tf, AXIS_DEFAULT); + /** Create a UnitNorm AbstractConstraint with the axis set to {@link #AXIS_DEFAULT} */ + public UnitNorm() { + this(AXIS_DEFAULT); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axis axis along which to calculate weight norms. 
*/ - public UnitNorm(Ops tf, int axis) { - this(tf, new int[] {axis}); + public UnitNorm(int axis) { + this(new int[] {axis}); } /** - * Create a UnitNorm Constraint + * Create a UnitNorm AbstractConstraint * - * @param tf the TensorFlow Ops * @param axes axes along which to calculate weight norms. */ - public UnitNorm(Ops tf, int[] axes) { - super(tf); + public UnitNorm(int[] axes) { + super(); this.axes = axes; } /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Ops tf, Operand weights) { Class type = weights.type(); - Ops tf = getTF(); return tf.math.div( - weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(weights, getAxes()))); + weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(tf, weights, getAxes()))); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java index 9c1fa9ac287..56e3d310280 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java @@ -14,29 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; /** Abstract base class for all Initializers */ public abstract class BaseInitializer implements Initializer { - protected final Ops tf; + private final String name; - /** - * Creates an Initializer - * - * @param tf the TensorFlow Ops - */ - protected BaseInitializer(Ops tf) { - this.tf = tf; + /** Creates an Initializer */ + protected BaseInitializer() { + name = getClass().getSimpleName(); } /** - * Gets the TensorFlow Ops + * Gets the name for this initializer * - * @return the TensorFlow Ops + * @return the name for this initializer */ 
- public Ops getTF() { - return tf; + public String getName() { + return name; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 4a2df86d74b..508fb69fd55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a constant value. * @@ -30,7 +32,7 @@ * Constant<TFloat32> initializer = * new org.tensorflow.framework.initializers.Constant<>(tf, 3f); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The Type for the call operation @@ -45,11 +47,10 @@ public class Constant extends BaseInitializer { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a long value used for the constant. */ - public Constant(Ops tf, long value) { - super(tf); + public Constant(long value) { + super(); longValue = value; doubleValue = 0; booleanValue = false; @@ -59,11 +60,10 @@ public Constant(Ops tf, long value) { /** * Creates an Initializer that generates tensors with a constant value. * - * @param tf the TensorFlow Ops * @param value a double value used for the constant. */ - public Constant(Ops tf, double value) { - super(tf); + public Constant(double value) { + super(); doubleValue = value; longValue = 0; booleanValue = false; @@ -73,11 +73,10 @@ public Constant(Ops tf, double value) { /** * Creates an Initializer that generates tensors with a constant value. 
* - * @param tf the TensorFlow Ops * @param value a boolean value used for the constant. */ - public Constant(Ops tf, boolean value) { - super(tf); + public Constant(boolean value) { + super(); booleanValue = value; doubleValue = 0; longValue = 0; @@ -86,17 +85,19 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); + return tf.fill(dims, cast(tf, tf.constant(longValue), type)); case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); + return tf.fill(dims, cast(tf, tf.constant(doubleValue), type)); default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); + return tf.fill(dims, cast(tf, tf.constant(booleanValue), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 894bd073758..4a39c3839f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -43,7 +42,7 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, 
tf.constant(Shape.of(2,2)), TFloat32.class); * * *

      Glorot Uniform: @@ -54,12 +53,14 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

      NOTE: + * *

      For a GlorotNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. + * *

      For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. * @@ -74,13 +75,12 @@ public class Glorot extends VarianceScaling { /** * Creates a Glorot initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public Glorot(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_AVG, distribution, seed); + public Glorot(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_AVG, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 3a91b72b0d0..4a9fa8a7849 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -38,7 +37,7 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.TRUNCATED_NORMAL, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

      He Uniform: @@ -49,14 +48,16 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.UNIFORM, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

      NOTE: + * *

      For an HeNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. - *

      For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} - * for the distribution parameter. + * + *

      For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for + * the distribution parameter. * * @param The TType for the call operation * @see extends VarianceScaling { /** * Creates an He Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the He initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. * @see VarianceScaling.Distribution */ - public He(Ops tf, Distribution distribution, long seed) { - super(tf, SCALE, Mode.FAN_IN, distribution, seed); + public He(Distribution distribution, long seed) { + super(SCALE, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index f672c9f1e85..34a77520406 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -21,6 +21,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates the identity matrix. * @@ -32,40 +34,34 @@ * Identity<TFloat32> initializer = * new org.tensorflow.framework.initializers.Identity<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; - private final double gain; - /** - * Creates an Initializer that generates the identity matrix. 
- * - * @param tf the TensorFlow Ops - */ - public Identity(Ops tf) { - super(tf); - this.gain = GAIN_DEFAULT; + /** Creates an Initializer that generates the identity matrix. */ + public Identity() { + this(GAIN_DEFAULT); } /** * Creates an Initializer that generates the identity matrix. * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Identity Matrix */ - public Identity(Ops tf, double gain) { - super(tf); + public Identity(double gain) { + super(); this.gain = gain; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -75,9 +71,9 @@ public Operand call(Operand dims, Class type) { Shape diagShape = Shape.of(diagSize); Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), type); + Operand zero = cast(tf, tf.constant(0), type); Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); + tf.fill(tf.constant(diagShape.asArray()), cast(tf, tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -91,6 +87,6 @@ public Operand call(Operand dims, Class type) { op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); + return tf.math.mul(op, cast(tf, tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 4beb218783b..d6593b770e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -15,6 +15,7 @@ package 
org.tensorflow.framework.initializers; import org.tensorflow.Operand; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -23,14 +24,18 @@ * * @param The data Type for initializer operation */ +@FunctionalInterface public interface Initializer { /** * Generates the operation used to perform the initialization. * + * @param tf the TensorFlow Ops * @param dims the shape dimensions * @param type the type of tensor + * @throws IllegalStateException if the object has not been initialized with the TensorFlow + * Platform. * @return An operand for the initialization. */ - Operand call(Operand dims, Class type); + Operand call(Ops tf, Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index 38e68ef688b..364c5fb9285 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; /** @@ -27,7 +26,7 @@ * stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the * weight tensor. * - *

      If the distribution is UNIFORM, itraws samples from a uniform distribution within + *

      If the distribution is UNIFORM, it draws samples from a uniform distribution within * [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is * the number of input units in the weight tensor) * @@ -41,7 +40,7 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * *

      LeCun Uniform: @@ -52,14 +51,15 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * * * *

      NOTE: * * - *

      For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * + *

      For a LeCunNormal equivalent initializer, use {@link + * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * * *

      For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * * for the distribution parameter. * @@ -79,12 +79,11 @@ public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer * - * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public LeCun(Ops tf, Distribution distribution, long seed) { - super(tf, 1.0, Mode.FAN_IN, distribution, seed); + public LeCun(Distribution distribution, long seed) { + super(1.0, Mode.FAN_IN, distribution, seed); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b8eb0c418e9..6e818d30bd7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors initialized to 1. 
* @@ -30,7 +32,7 @@ * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,21 +48,21 @@ public class Ones extends BaseInitializer { * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param tf the TensorFlow Ops */ - public Ones(Ops tf) { - super(tf); + public Ones() { + super(); } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); + return tf.fill(dims, cast(tf, tf.constant(1), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index a5b466e118e..519d0cd042e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -23,6 +23,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates an orthogonal matrix. 
* @@ -42,7 +44,7 @@ * Orthogonal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.Orthogonal<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -57,31 +59,30 @@ public class Orthogonal extends BaseInitializer { /** * Creates an Orthogonal Initializer using {@link #GAIN_DEFAULT} for the gain. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public Orthogonal(Ops tf, long seed) { - this(tf, GAIN_DEFAULT, seed); + public Orthogonal(long seed) { + this(GAIN_DEFAULT, seed); } /** * Creates an Orthogonal Initializer * - * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Matrix. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public Orthogonal(Ops tf, double gain, long seed) { - super(tf); + public Orthogonal(double gain, long seed) { + super(); this.gain = gain; this.seed = seed; } /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -101,10 +102,10 @@ public Operand call(Operand dims, Class type) { Output qo = qrOp.q(); Output ro = qrOp.r(); Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); + tf.linalg.matrixDiagPart(ro, tf.constant(0), cast(tf, tf.constant(0), type)); Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); + return tf.math.mul(qop, cast(tf, tf.constant(this.gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 38ab194a56b..9a52a641416 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a normal distribution. 
* @@ -29,7 +31,7 @@ * RandomNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,37 +49,34 @@ public class RandomNormal extends BaseInitializer { * Creates the RandomUniform initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public RandomNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates the RandomUniform initializer using {@link #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomNormal(Ops tf, double mean, long seed) { - this(tf, mean, STDDEV_DEFAULT, seed); + public RandomNormal(double mean, long seed) { + this(mean, STDDEV_DEFAULT, seed); } /** * creates the RandomUniform initializer * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public RandomNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -85,10 +84,11 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); + Operand op = tf.math.mul(distOp, cast(tf, tf.constant(this.stddev), type)); + return tf.math.add(op, cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 787af15f709..7288024f5b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a uniform distribution. 
* @@ -31,7 +33,7 @@ * RandomUniform<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomUniform<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,28 +48,26 @@ public class RandomUniform extends BaseInitializer { private final long seed; /** - * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and - * {@link #MAXVAL_DEFAULT} for the maxval + * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and {@link + * #MAXVAL_DEFAULT} for the maxval * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public RandomUniform(Ops tf, long seed) { - this(tf, MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); + public RandomUniform(long seed) { + this(MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); } /** * Creates a RandomUniform initializer * - * @param tf the TensorFlow Ops * @param minval Lower bound of the range of random values to generate (inclusive). * @param maxval Upper bound of the range of random values to generate (exclusive). * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public RandomUniform(Ops tf, double minval, double maxval, long seed) { - super(tf); + public RandomUniform(double minval, double maxval, long seed) { + super(); this.minval = minval; this.maxval = maxval; this.seed = seed; @@ -75,26 +75,27 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Ops tf, Operand dims, Class type) { + Operand distOp; if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), type), - tf.dtypes.cast(tf.constant(this.maxval), type), + cast(tf, tf.constant(this.minval), type), + cast(tf, tf.constant(this.maxval), type), options); } else { long[] seeds = {seed, 0}; distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval), type)); } } else { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); - distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval - this.minval), type)); + distOp = tf.math.add(distOp, cast(tf, tf.constant(this.minval), type)); } } return distOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index d3cfec26338..8069d5d9c7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -19,6 +19,8 @@ import 
org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates a truncated normal distribution. * @@ -29,7 +31,7 @@ * TruncatedNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.TruncatedNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -47,25 +49,23 @@ public class TruncatedNormal extends BaseInitializer { * Creates a TruncatedNormal Initializer using {@link #MEAN_DEFAULT} for the mean and {@link * #STDDEV_DEFAULT} for the standard deviation. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. */ - public TruncatedNormal(Ops tf, long seed) { - this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); + public TruncatedNormal(long seed) { + this(MEAN_DEFAULT, STDDEV_DEFAULT, seed); } /** * Creates a TruncatedNormal Initializer. * - * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and dtype. 
*/ - public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { - super(tf); + public TruncatedNormal(double mean, double stddev, long seed) { + super(); this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -73,11 +73,12 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - long[] seeds = {seed,0}; + public Operand call(Ops tf, Operand dims, Class type) { + + long[] seeds = {seed, 0}; Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), - tf.dtypes.cast(tf.constant(mean), type)); + tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)), + cast(tf, tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index 5d951450505..a04e4a9a378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -21,11 +21,13 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

      With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from - * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after + *

      With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a + * truncated/untruncated normal distribution with a mean of zero and a standard deviation (after * truncation, if used) stddev = Math.sqrt(scale / n), where n is: * *

        @@ -46,7 +48,7 @@ * new org.tensorflow.framework.initializers.VarianceScaling<>( * tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(Ops tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -64,28 +66,25 @@ public class VarianceScaling extends BaseInitializer { private final Distribution distribution; private final long seed; - /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param seed sed to create random seeds. */ - public VarianceScaling(Ops tf, long seed) { - this(tf, SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); + public VarianceScaling(long seed) { + this(SCALE_DEFAULT, MODE_DEFAULT, DISTRIBUTION_DEFAULT, seed); } /** * Creates a VarianceScaling Initializer * - * @param tf the TensorFlow Ops * @param scale Scaling factor (positive float). * @param mode the mode for the variance * @param distribution Random distribution to use. * @param seed Used to create random seeds. 
*/ - public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distribution, long seed) { - super(tf); + public VarianceScaling(double scale, Mode mode, Distribution distribution, long seed) { + super(); if (scale <= 0.0) { throw new IllegalArgumentException("scale must be greater than 0, got " + scale); } @@ -97,8 +96,9 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); + public Operand call(Ops tf, Operand dims, Class type) { + + Shape shape = ShapeUtils.toShape(tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); switch (mode) { @@ -119,18 +119,18 @@ public Operand call(Operand dims, Class type) { switch (distribution) { case TRUNCATED_NORMAL: distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); - stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + stddev = Math.sqrt(lscale) / 0.87962566103423978; + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case NORMAL: distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case UNIFORM: distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; } return mulOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 4298493ac44..f581d247deb 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -28,24 +28,21 @@ * Zeros<TFloat32> initializer = * new org.tensorflow.framework.initializers.Zeros<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); + * initializer.call(tf, tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ public class Zeros extends BaseInitializer { - /** - * Creates an Initializer that sets all values to one. - * - * @param tf the TensorFlow Ops - */ - public Zeros(Ops tf) { - super(tf); + /** Creates an Initializer that sets all values to zero. */ + public Zeros() { + super(); } @Override - public Operand call(Operand dims, Class dtype) { - return tf.zeros(dims, dtype); + public Operand call(Ops tf, Operand dims, Class type) { + + return tf.zeros(dims, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 3417c07372a..0c7c6abf8af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * BinaryCrossentropy bce = new BinaryCrossentropy(tf); - * Operand<TFloat32> result = bce.call(labels, predictions); + * Operand<TFloat32> result = bce.call(tf, labels, predictions); * // produces 0.815
* * @@ -43,7 +44,7 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
        - *    Operand<TFloat32> result = bce.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = bce.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.458f
          * 
        * @@ -51,7 +52,7 @@ * *
          *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = bce.call(labels, predictions);
        + *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
          *    // produces 1.630f
          * 
        * @@ -59,11 +60,11 @@ * *
          *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = bce.call(labels, predictions);
        + *    Operand<TFloat32> result = bce.call(tf, labels, predictions);
          *    // produces [0.916f, 0.714f]
          * 
        */ -public class BinaryCrossentropy extends Loss { +public class BinaryCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; @@ -71,70 +72,63 @@ public class BinaryCrossentropy extends Loss { private final float labelSmoothing; /** - * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Binary Crossentropy AbstractLoss using {@link Class#getSimpleName()} as the loss + * name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public BinaryCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link * #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public BinaryCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); + public BinaryCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link - * Loss#REDUCTION_DEFAULT}, + * AbstractLoss#REDUCTION_DEFAULT}, * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a - * reduction of {@link Loss#REDUCTION_DEFAULT}. + * reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, - * and a reduction of {@link Loss#REDUCTION_DEFAULT}. + * and a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. 
When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** - * Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}. + * Creates a Binary Crossentropy loss using a reduction of {@link AbstractLoss#REDUCTION_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -142,14 +136,13 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ - public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, @@ -157,14 +150,13 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSm * correspond to heavier smoothing. 
* @param reduction Type of Reduction to apply to the loss. */ - public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction); + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction); } /** * Creates a Binary Crossentropy loss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, @@ -175,8 +167,8 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Redu * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public BinaryCrossentropy( - Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. 
and <= 1, found " + labelSmoothing); @@ -207,24 +199,25 @@ public BinaryCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.binaryCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 5aac163c1e4..7d65353b004 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * CategoricalCrossentropy cce = new CategoricalCrossentropy(tf); - * 
Operand<TFloat32> result = cce.call(labels, predictions); + * Operand<TFloat32> result = cce.call(tf, labels, predictions); * // produces 1.177 * * @@ -45,15 +46,15 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
        - *    Operand<TFloat32> result = cce.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = cce.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.814f
          * 
        * *

        Using SUM reduction type: * *

        - *    CategoricalCrossentropy cce = new CategoricalCrossentropy(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = cce.call(labels, predictions);
        + *    CategoricalCrossentropy cce = new CategoricalCrossentropy(Reduction.SUM);
        + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
          *    // produces 2.354f
          * 
        * @@ -61,12 +62,12 @@ * *
          *    CategoricalCrossentropy cce =
        - *        new CategoricalCrossentropy(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = cce.call(labels, predictions);
        + *        new CategoricalCrossentropy(Reduction.NONE);
        + *    Operand<TFloat32> result = cce.call(tf, labels, predictions);
          *    // produces [0.0513f, 2.303f]
          * 
        */ -public class CategoricalCrossentropy extends Loss { +public class CategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST; @@ -76,98 +77,90 @@ public class CategoricalCrossentropy extends Loss { private final int axis; /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and an axis of {@link - * #DEFAULT_AXIS} - * - * @param tf the TensorFlow Ops + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis + * of {@link #DEFAULT_AXIS} */ - public CategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #FROM_LOGITS_DEFAULT} for + * fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss */ - public CategoricalCrossentropy(Ops tf, String name) { - this(tf, name, 
FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for * labelSmoothing and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link - * #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss {@link #FROM_LOGITS_DEFAULT} for fromLogits, + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param reduction Type of Reduction to apply to loss. 
*/ - public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a AbstractLoss Reduction of + * {@link AbstractLoss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of - * {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a + * channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, 
REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and a channel + * axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { - this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing) { + this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -175,15 +168,14 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) *
        means that we will use a value of 0.1 for label 0 and * 0.9 for label 1 */ - public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { - this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + public CategoricalCrossentropy(String name, boolean fromLogits, float labelSmoothing) { + this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a channel axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means @@ -191,15 +183,13 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * for label 1 * @param reduction Type of Reduction to apply to loss. */ - public CategoricalCrossentropy( - Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); + public CategoricalCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss + * Creates a categorical cross entropy AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are @@ -213,13 +203,8 @@ public CategoricalCrossentropy( * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. 
*/ public CategoricalCrossentropy( - Ops tf, - String name, - boolean fromLogits, - float labelSmoothing, - Reduction reduction, - int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, float labelSmoothing, Reduction reduction, int axis) { + super(name, reduction); if (labelSmoothing < 0 || labelSmoothing > 1) throw new IllegalArgumentException( "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); @@ -251,24 +236,24 @@ public CategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.categoricalCrossentropy( - getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.categoricalCrossentropy(tf, labels, lPredictions, fromLogits, labelSmoothing, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 73837ed1756..c9987fb0884 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import 
org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -35,7 +36,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * CategoricalHinge categoricalHinge = new CategoricalHinge(tf); - * Operand<TFloat32> result = categoricalHinge.call(labels, predictions); + * Operand<TFloat32> result = categoricalHinge.call(Ops tf, labels, predictions); * // produces 1.4 * * @@ -43,7 +44,7 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1f, 0.f});
        - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.6f
          * 
        * @@ -51,7 +52,7 @@ * *
          *    CategoricalHinge categoricalHinge = new CategoricalHinge(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
        + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
          *    // produces 2.8f
          * 
        * @@ -60,48 +61,45 @@ *
          *    CategoricalHinge categoricalHinge =
          *        new CategoricalHinge(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
        + *    Operand<TFloat32> result = categoricalHinge.call(tf, labels, predictions);
          *    // produces [1.2f, 1.6f]
          * 
        */ -public class CategoricalHinge extends Loss { +public class CategoricalHinge extends AbstractLoss { /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public CategoricalHinge(Ops tf) { - super(tf); + public CategoricalHinge() { + super(); } /** - * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Categorical Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CategoricalHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public CategoricalHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Categorical Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public CategoricalHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public CategoricalHinge(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.categoricalHinge(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 0a18d93caf3..ac810139d71 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -40,7 +41,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 0.f}, {1.f, 1.f}}); * CosineSimilarity cosineLoss = new CosineSimilarity(tf); - * Operand<TFloat32> result = cosineLoss.call(labels, predictions); + * Operand<TFloat32> result = cosineLoss.call(Ops tf, labels, predictions); * // produces -0.5 * * @@ -48,7 +49,7 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
        - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions, sampleWeight);
          *    // produces -0.0999f
          * 
        * @@ -56,7 +57,7 @@ * *
          *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
        + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
          *    // produces -0.999f
          * 
        * @@ -64,165 +65,155 @@ * *
          *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
        + *    Operand<TFloat32> result = cosineLoss.call(tf, labels, predictions);
          *    // produces [-0.f, -0.999f]
          * 
        */ -public class CosineSimilarity extends Loss { +public class CosineSimilarity extends AbstractLoss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; private final int[] axis; /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis - * of {@link #DEFAULT_AXIS}, and a Loss Reduction of {@link #DEFAULT_REDUCTION} - * - * @param tf the TensorFlow Ops + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * an axis of {@link #DEFAULT_AXIS}, and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} */ - public CosineSimilarity(Ops tf) { + public CosineSimilarity() { - this(tf, null, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(null, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS}, and a Loss Reduction - * of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS}, and a + * AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss */ - public CosineSimilarity(Ops tf, String name) { + public CosineSimilarity(String name) { - this(tf, name, DEFAULT_AXIS, DEFAULT_REDUCTION); + this(name, DEFAULT_AXIS, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, int axis) { + public CosineSimilarity(int axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a - * Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name, + * and a AbstractLoss Reduction of {@link #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, int[] axis) { + public CosineSimilarity(int[] axis) { - this(tf, null, axis, DEFAULT_REDUCTION); + this(null, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. */ - public CosineSimilarity(Ops tf, String name, int axis) { + public CosineSimilarity(String name, int axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * Creates a Cosine Similarity AbstractLoss using a AbstractLoss Reduction of {@link + * #DEFAULT_REDUCTION} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis) { + public CosineSimilarity(String name, int[] axis) { - this(tf, name, axis, DEFAULT_REDUCTION); + this(name, axis, DEFAULT_REDUCTION); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an - * axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, Reduction reduction) { + public CosineSimilarity(Reduction reduction) { - this(tf, null, DEFAULT_AXIS, reduction); + this(null, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS} + * Creates a Cosine Similarity AbstractLoss using an axis of {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, Reduction reduction) { + public CosineSimilarity(String name, Reduction reduction) { - this(tf, name, DEFAULT_AXIS, reduction); + this(name, DEFAULT_AXIS, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + public CosineSimilarity(int axis, Reduction reduction) { - this(tf, null, new int[] {axis}, reduction); + this(null, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Cosine Similarity AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + public CosineSimilarity(int[] axis, Reduction reduction) { - this(tf, null, axis, reduction); + this(null, axis, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. */ - public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { - this(tf, name, new int[] {axis}, reduction); + public CosineSimilarity(String name, int axis, Reduction reduction) { + this(name, new int[] {axis}, reduction); } /** - * Creates a Cosine Similarity Loss + * Creates a Cosine Similarity AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param axis The dimension along which the cosine similarity is computed. * @param reduction Type of Reduction to apply to the loss. 
*/ - public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { - super(tf, name, reduction); + public CosineSimilarity(String name, int[] axis, Reduction reduction) { + super(name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.cosineSimilarity(tf, labels, predictions, axis); losses = tf.math.neg(losses); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index d4c350ef06c..05c5b47e329 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Hinge hingeLoss = new Hinge(tf); - * Operand<TFloat32> result = hingeLoss.call(labels, predictions); + * Operand<TFloat32> result = hingeLoss.call(Ops tf, labels, predictions); * // produces 1.3f * * @@ -45,57 +46,53 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
        - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.55f
          * 
        * *

        Using SUM reduction type: * *

        - *    Hinge hingeLoss = new Hinge(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
        + *    Hinge hingeLoss = new Hinge(Reduction.SUM);
        + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
          *    // produces 2.6f
          * 
        * *

        Using NONE reduction type: * *

        - *    Hinge hingeLoss = new Hinge(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
        + *    Hinge hingeLoss = new Hinge(Reduction.NONE);
        + *    Operand<TFloat32> result = hingeLoss.call(tf, labels, predictions);
          *    // produces [1.1f, 1.5f]
          * 
        */ -public class Hinge extends Loss { +public class Hinge extends AbstractLoss { /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction - * of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Hinge(Ops tf) { - this(tf, null, Reduction.AUTO); + public Hinge() { + this(null, Reduction.AUTO); } /** - * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Hinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public Hinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public Hinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Hinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -122,15 +119,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.hinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.hinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index b1aee1b0656..c9a7d7edcb8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -39,7 +40,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * Huber huberLoss = new Huber(tf); - * Operand<TFloat32> result = huberLoss.call(labels, predictions); + * Operand<TFloat32> result = huberLoss.call(Ops tf, labels, predictions); * // produces 0.155 * * 
@@ -47,7 +48,7 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
        - *    Operand<TFloat32> result = huberLoss.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.09f
          * 
        * @@ -55,7 +56,7 @@ * *
          *    Huber huberLoss = new Huber(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
        + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
          *    // produces 0.32f
          * 
        * @@ -63,78 +64,74 @@ * *
          *    Huber huberLoss = new Huber(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
        + *    Operand<TFloat32> result = huberLoss.call(tf, labels, predictions);
          *    // produces [0.18f, 0.13f]
          * 
        * * @see
        Huber loss */ -public class Huber extends Loss { +public class Huber extends AbstractLoss { public static final float DELTA_DEFAULT = 1.0f; private final float delta; /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name, {@link + * #DELTA_DEFAULT} as the delta and a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} */ - public Huber(Ops tf) { - this(tf, null, DELTA_DEFAULT, Reduction.AUTO); + public Huber() { + this(null, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Huber(Ops tf, String name) { - this(tf, name, DELTA_DEFAULT, Reduction.AUTO); + public Huber(String name) { + this(name, DELTA_DEFAULT, Reduction.AUTO); } /** - * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name and and {@link - * #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link Class#getSimpleName()} as the loss name and and + * {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, Reduction reduction) { - this(tf, null, DELTA_DEFAULT, reduction); + public Huber(Reduction reduction) { + this(null, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta + * Creates a Huber AbstractLoss using {@link #DELTA_DEFAULT} as the delta * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ - public Huber(Ops tf, String name, Reduction reduction) { - this(tf, name, DELTA_DEFAULT, reduction); + public Huber(String name, Reduction reduction) { + this(name, DELTA_DEFAULT, reduction); } /** - * Creates a Huber Loss + * Creates a Huber AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Huber(Ops tf, String name, float delta, Reduction reduction) { - super(tf, name, reduction); + public Huber(String name, float delta, Reduction reduction) { + super(name, reduction); this.delta = delta; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.huber(getTF(), labels, predictions, delta); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.huber(tf, labels, predictions, delta); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 2aa1f72092b..ef5d88539db 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -31,8 +32,8 @@ * tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}}); * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); - * KLDivergence kld = new KLDivergence(tf); - * Operand<TFloat32> result = kld.call(labels, predictions); + * KLDivergence kld = new KLDivergence(); + * Operand<TFloat32> result = kld.call(Ops tf, labels, predictions); * // produces 0.458 * * @@ -40,68 +41,65 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
        - *    Operand<TFloat32> result = kld.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = kld.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.366f
          * 
        * *

        Using SUM reduction type: * *

        - *    KLDivergence kld = new KLDivergence(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = kld.call(labels, predictions);
        + *    KLDivergence kld = new KLDivergence(Reduction.SUM);
        + *    Operand<TFloat32> result = kld.call(tf, labels, predictions);
          *    // produces 0.916f
          * 
        * *

        Using NONE reduction type: * *

        - *    KLDivergence kld = new KLDivergence(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = kld.call(labels, predictions);
        + *    KLDivergence kld = new KLDivergence(Reduction.NONE);
        + *    Operand<TFloat32> result = kld.call(tf, labels, predictions);
          *    // produces [0.916f, -3.08e-06f]
          * 
        * * @see Kullback?Leibler * divergence */ -public class KLDivergence extends Loss { +public class KLDivergence extends AbstractLoss { /** - * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Kullback Leibler Divergence AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public KLDivergence(Ops tf) { - super(tf); + public KLDivergence() { + super(); } /** - * Creates a Kullback Leibler Divergence Loss Loss using {@link Class#getSimpleName()} as the loss - * name + * Creates a Kullback Leibler Divergence AbstractLoss AbstractLoss using {@link + * Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public KLDivergence(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public KLDivergence(Reduction reduction) { + super(null, reduction); } /** - * Creates a Kullback Leibler Divergence Loss + * Creates a Kullback Leibler Divergence AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public KLDivergence(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public KLDivergence(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.kullbackLeiblerDivergence(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index a11d582e527..02200c3a9e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -33,7 +34,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * LogCosh logcosh = new LogCosh(tf); - * Operand<TFloat32> result = logcosh.call(labels, predictions); + * Operand<TFloat32> result = logcosh.call(Ops tf, labels, predictions); * // produces 0.108 * * @@ -41,74 +42,71 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
        - *    Operand<TFloat32> result = logcosh.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.087f
          * 
        * *

        Using SUM reduction type: * *

        - *    LogCosh logcosh = new LogCosh(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
        + *    LogCosh logcosh = new LogCosh(Reduction.SUM);
        + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions);
          *    // produces 0.217f
          * 
        * *

        Using NONE reduction type: * *

        - *    LogCosh logcosh = new LogCosh(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = logcosh.call(labels, predictions);
        + *    LogCosh logcosh = new LogCosh(Reduction.NONE);
        + *    Operand<TFloat32> result = logcosh.call(tf, labels, predictions);
          *    // produces [0.217f, 0f]
          * 
        */ -public class LogCosh extends Loss { +public class LogCosh extends AbstractLoss { /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public LogCosh(Ops tf) { - this(tf, null, Reduction.AUTO); + public LogCosh() { + this(null, Reduction.AUTO); } /** - * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a LogCosh AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public LogCosh(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public LogCosh(String name) { + this(name, Reduction.AUTO); } /** - * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name + * Creates a LogCosh AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public LogCosh(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public LogCosh(Reduction reduction) { + this(null, reduction); } /** - * Creates a LogCosh Loss + * Creates a LogCosh AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public LogCosh(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public LogCosh(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.logCosh(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.logCosh(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index cdd35d28aba..4dd5bce6cde 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -18,60 +18,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -public abstract class Loss { - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; - - protected final Ops tf; - protected final Reduction reduction; - - /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - */ - protected Loss(Ops tf) { - this(tf, null, Reduction.AUTO); - } - - /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. - */ - protected Loss(Ops tf, String name) { - this(tf, name, Reduction.AUTO); - } - - /** - * Creates a Loss - * - * @param tf the TensorFlow Ops - * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. - * @param reduction Type of Reduction to apply to the loss. - */ - protected Loss(Ops tf, String name, Reduction reduction) { - this.tf = name != null ? tf.withSubScope(name) : tf.withSubScope(getClass().getSimpleName()); - this.reduction = reduction; - } - - /** - * Calculates the loss - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param The data type of the predictions and loss. - * @return the loss - */ - public Operand call( - Operand labels, Operand predictions) { - return call(labels, predictions, null); - } +/** Interface for loss calc ulation */ +@FunctionalInterface +public interface Loss { /** * Generates an Operand that calculates the loss. * + * @param tf the TensorFlow Ops * @param labels the truth values or labels * @param predictions the predictions * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is @@ -84,24 +38,6 @@ public Operand call( * @param The data type of the predictions, sampleWeights and loss. 
* @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the loss reduction - * - * @return the loss reduction - */ - public Reduction getReduction() { - return reduction; - } + Operand call( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 03a3cf70110..d85bdf3561a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsoluteError mae = new MeanAbsoluteError(tf); - * Operand<TFloat32> result = mae.call(labels, predictions); + * Operand<TFloat32> result = mae.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
        - *    Operand<TFloat32> result = mae.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = mae.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.25f
          * 
        * *

        Using SUM reduction type: * *

        - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = mae.call(labels, predictions);
        + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.SUM);
        + *    Operand<TFloat32> result = mae.call(tf, labels, predictions);
          *    // produces 1.0f
          * 
        * *

        Using NONE reduction type: * *

        - *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = mae.call(labels, predictions);
        + *    MeanAbsoluteError mae = new MeanAbsoluteError(Reduction.NONE);
        + *    Operand<TFloat32> result = mae.call(tf, labels, predictions);
          *    // produces [0.5f, 0.5f]
          * 
        */ -public class MeanAbsoluteError extends Loss { +public class MeanAbsoluteError extends AbstractLoss { /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsoluteError(Ops tf) { - super(tf); + public MeanAbsoluteError() { + super(); } /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsoluteError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsoluteError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsoluteError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsoluteError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsoluteError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsoluteError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 6c5242df4f2..ed5c7d73e2f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf); - * Operand<TFloat32> result = mape.call(labels, predictions); + * Operand<TFloat32> result = mape.call(Ops tf, labels, predictions); * // produces 50f * * @@ -40,64 +41,62 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
        - *    Operand<TFloat32> result = mape.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = mape.call(tf, labels, predictions, sampleWeight);
          *    // produces 20f
          * 
        * *

        Using SUM reduction type: * *

        - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = mape.call(labels, predictions);
        + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.SUM);
        + *    Operand<TFloat32> result = mape.call(tf, labels, predictions);
          *    // produces 100.0f
          * 
        * *

        Using NONE reduction type: * *

        - *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = mape.call(labels, predictions);
        + *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(Reduction.NONE);
        + *    Operand<TFloat32> result = mape.call(tf, labels, predictions);
          *    // produces [25f, 75f]
          * 
        */ -public class MeanAbsolutePercentageError extends Loss { +public class MeanAbsolutePercentageError extends AbstractLoss { /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanAbsolutePercentageError(Ops tf) { - super(tf); + public MeanAbsolutePercentageError() { + super(); } /** - * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanAbsolutePercentageError AbstractLoss using {@link Class#getSimpleName()} as the + * loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanAbsolutePercentageError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanAbsolutePercentageError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanAbsolutePercentageError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanAbsolutePercentageError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanAbsolutePercentageError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index f975db55c44..c6898e20f20 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredError mse = new MeanSquaredError(tf); - * Operand<TFloat32> result = mse.call(labels, predictions); + * Operand<TFloat32> result = mse.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,64 +41,61 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
        - *    Operand<TFloat32> result = mse.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = mse.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.25f
          * 
        * *

        Using SUM reduction type: * *

        - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = mse.call(labels, predictions);
        + *    MeanSquaredError mse = new MeanSquaredError(Reduction.SUM);
        + *    Operand<TFloat32> result = mse.call(tf, labels, predictions);
          *    // produces 1.0f
          * 
        * *

        Using NONE reduction type: * *

        - *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = mse.call(labels, predictions);
        + *    MeanSquaredError mse = new MeanSquaredError(Reduction.NONE);
        + *    Operand<TFloat32> result = mse.call(tf, labels, predictions);
          *    // produces [0.5f, 0.5f]
          * 
        */ -public class MeanSquaredError extends Loss { +public class MeanSquaredError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredError(Ops tf) { - super(tf); + public MeanSquaredError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 11b8e157e90..3d325a98a6a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}}); * MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf); - * Operand<TFloat32> result = msle.call(labels, predictions); + * Operand<TFloat32> result = msle.call(Ops tf, labels, predictions); * // produces 0.240f * * @@ -40,64 +41,61 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
        - *    Operand<TFloat32> result = msle.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = msle.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.120f
          * 
        * *

        Using SUM reduction type: * *

        - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = msle.call(labels, predictions);
        + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.SUM);
        + *    Operand<TFloat32> result = msle.call(tf, labels, predictions);
          *    // produces 0.480f
          * 
        * *

        Using NONE reduction type: * *

        - *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = msle.call(labels, predictions);
        + *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(Reduction.NONE);
        + *    Operand<TFloat32> result = msle.call(tf, labels, predictions);
          *    // produces [0.240f, 0.240f]
          * 
        */ -public class MeanSquaredLogarithmicError extends Loss { +public class MeanSquaredLogarithmicError extends AbstractLoss { /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name + * and a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public MeanSquaredLogarithmicError(Ops tf) { - super(tf); + public MeanSquaredLogarithmicError() { + super(); } /** - * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * Creates a MeanSquaredError AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public MeanSquaredLogarithmicError(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public MeanSquaredLogarithmicError(Reduction reduction) { + super(null, reduction); } /** * Creates a MeanSquaredError * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. 
*/ - public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public MeanSquaredLogarithmicError(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.meanSquaredLogarithmicError(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 78324acf8a5..a6eb29b7109 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -32,7 +33,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}}); * Poisson poissonLoss = new Poisson(tf); - * Operand<TFloat32> result = poissonLoss.call(labels, predictions); + * Operand<TFloat32> result = poissonLoss.call(Ops tf, labels, predictions); * // produces 0.5f * * @@ -40,74 +41,71 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
        - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.4f
          * 
        * *

        Using SUM reduction type: * *

        - *    Poisson poissonLoss = new Poisson(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
        + *    Poisson poissonLoss = new Poisson(Reduction.SUM);
        + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions);
          *    // produces 0.999f
          * 
        * *

        Using NONE reduction type: * *

        - *    Poisson poissonLoss = new Poisson(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
        + *    Poisson poissonLoss = new Poisson(Reduction.NONE);
        + *    Operand<TFloat32> result = poissonLoss.call(tf, labels, predictions);
          *    // produces [0.999f, 0f]
          * 
        */ -public class Poisson extends Loss { +public class Poisson extends AbstractLoss { /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public Poisson(Ops tf) { - this(tf, null, Reduction.AUTO); + public Poisson() { + this(null, Reduction.AUTO); } /** - * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates a Poisson AbstractLoss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ - public Poisson(Ops tf, String name) { - this(tf, name, Reduction.AUTO); + public Poisson(String name) { + this(name, Reduction.AUTO); } /** - * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Poisson AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public Poisson(Ops tf, Reduction reduction) { - this(tf, null, reduction); + public Poisson(Reduction reduction) { + this(null, reduction); } /** - * Creates a Poisson Loss + * Creates a Poisson AbstractLoss * - * @param tf the TensorFlow Ops * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ - public Poisson(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public Poisson(String name, Reduction reduction) { + super(name, reduction); } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.poisson(getTF(), labels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + + Operand losses = Losses.poisson(tf, labels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index 87ea43c6c3a..e40ec6d6ebb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.losses; /** - * Type of Loss Reduction + * Type of AbstractLoss Reduction * *

        {@link #AUTO} indicates that the reduction option will be determined by the usage context. For * almost all cases this defaults to {@link #SUM_OVER_BATCH_SIZE}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index d04cc67d5d9..291a91894b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -43,7 +44,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}}); * SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf); - * Operand<TFloat32> result = sparseCCE.call(labels, predictions); + * Operand<TFloat32> result = sparseCCE.call(Ops tf, labels, predictions); * // produces 1.177f * * @@ -51,27 +52,27 @@ * *

          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
        - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions, sampleWeight);
        + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions, sampleWeight);
          *    // produces 0.814f
          * 
        * *

        Using SUM reduction type: * *

        - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
        + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.SUM);
        + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
          *    // produces 2.354f
          * 
        * *

        Using NONE reduction type: * *

        - *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
        + *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(Reduction.NONE);
        + *    Operand<TFloat32> result = sparseCCE.call(tf, labels, predictions);
          *    // produces [0.0513f, 2.303f]
          * 
        */ -public class SparseCategoricalCrossentropy extends Loss { +public class SparseCategoricalCrossentropy extends AbstractLoss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final int AXIS_DEFAULT = -1; @@ -80,24 +81,23 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ - public SparseCategoricalCrossentropy(Ops tf) { - this(tf, null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy() { + this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of this loss function */ - public SparseCategoricalCrossentropy(Ops tf, String name) { - this(tf, name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name) { + this(name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -107,8 +107,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name) { * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. 
*/ - public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { - this(tf, null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(Reduction reduction) { + this(null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** @@ -119,32 +119,32 @@ public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { * @param name the name of this loss function * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { - this(tf, name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, Reduction reduction) { + this(name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a AbstractLoss Reduction of {@link + * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { - this(tf, name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(String name, boolean fromLogits) { + this(name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} and fromLogits={@link * #FROM_LOGITS_DEFAULT}. 
* * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { - this(tf, null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits) { + this(null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); } /** @@ -155,8 +155,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param reduction Type of Reduction to apply to loss. */ - public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduction) { - this(tf, null, fromLogits, reduction, AXIS_DEFAULT); + public SparseCategoricalCrossentropy(boolean fromLogits, Reduction reduction) { + this(null, fromLogits, reduction, AXIS_DEFAULT); } /** @@ -170,8 +170,8 @@ public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduc * and axis=1 corresponds to data format 'Channels First'. 
*/ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, Reduction reduction, int axis) { - super(tf, name, reduction); + String name, boolean fromLogits, Reduction reduction, int axis) { + super(name, reduction); this.fromLogits = fromLogits; this.axis = axis; } @@ -199,23 +199,24 @@ public SparseCategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = LossesHelper.rangeCheck( - getTF(), + tf, "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + cast(tf, tf.constant(0), predictions.type()), + cast(tf, tf.constant(1), predictions.type())); } else { lPredictions = predictions; } Operand losses = - Losses.sparseCategoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, axis); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + Losses.sparseCategoricalCrossentropy(tf, labels, lPredictions, fromLogits, axis); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index dadbdb3b95e..c804b463984 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import 
org.tensorflow.types.family.TNumber; @@ -37,7 +38,7 @@ * Operand<TFloat32> predictions = * tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}}); * SquaredHinge squaredHinge = new SquaredHinge(tf); - * Operand<TFloat32> result = squaredHinge.call(labels, predictions); + * Operand<TFloat32> result = squaredHinge.call(Ops tf, labels, predictions); * // produces 1.86f * * @@ -45,7 +46,7 @@ * *
          *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
        - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions,
        + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions,
          *                                                  sampleWeight);
          *    // produces 0.73f
          * 
        @@ -53,50 +54,46 @@ *

        Using SUM reduction type: * *

        - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.SUM);
        - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
        + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.SUM);
        + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
          *    // produces 3.72f
          * 
        * *

        Using NONE reduction type: * *

        - *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.NONE);
        - *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
        + *    SquaredHinge squaredHinge = new SquaredHinge(Reduction.NONE);
        + *    Operand<TFloat32> result = squaredHinge.call(tf, labels, predictions);
          *    // produces [1.46f, 2.26f]
          * 
        */ -public class SquaredHinge extends Loss { +public class SquaredHinge extends AbstractLoss { /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Loss#REDUCTION_DEFAULT} - * - * @param tf the TensorFlow Ops + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name and a + * AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} */ - public SquaredHinge(Ops tf) { - super(tf); + public SquaredHinge() { + super(); } /** - * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name + * Creates a Squared Hinge AbstractLoss using {@link Class#getSimpleName()} as the loss name * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, Reduction reduction) { - super(tf, null, reduction); + public SquaredHinge(Reduction reduction) { + super(null, reduction); } /** * Creates a Squared Hinge * - * @param tf the TensorFlow Ops * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ - public SquaredHinge(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); + public SquaredHinge(String name, Reduction reduction) { + super(name, reduction); } /** @@ -123,19 +120,17 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? 
(Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( - getTF(), + tf, "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); - Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); - return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + cast(tf, tf.constant(new int[] {-1, 0, 1}), predictions.type())); + Operand losses = Losses.squaredHinge(tf, tLabels, predictions); + return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java new file mode 100644 index 00000000000..9534f6fe3ad --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/AbstractLoss.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.Reduction; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +public abstract class AbstractLoss implements Loss { + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + + protected final Reduction reduction; + private final String name; + + /** + * Creates a AbstractLoss using {@link Class#getSimpleName()} as the name and a AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + */ + protected AbstractLoss() { + this(null, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss using a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} + * + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. + */ + protected AbstractLoss(String name) { + this(name, Reduction.AUTO); + } + + /** + * Creates a AbstractLoss + * + * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. + * @param reduction Type of Reduction to apply to the loss. + */ + protected AbstractLoss(String name, Reduction reduction) { + this.name = name == null ? getClass().getSimpleName() : name; + this.reduction = reduction; + } + + /** + * Calculates the loss + * + * @param tf the TensorFlow Ops + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the predictions and loss. 
+ * @return the loss + */ + public Operand call( + Ops tf, Operand labels, Operand predictions) { + return call(tf, labels, predictions, null); + } + + /** + * Gets the loss reduction + * + * @return the loss reduction + */ + public Reduction getReduction() { + return reduction; + } + + /** + * Gets the name for this loss + * + * @return the name for this loss + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index bc5047d5855..69cb2ee0dfe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -40,26 +40,26 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of - * recall and precision values. The area under the ROC-curve is therefore computed using the height - * of the recall values by the false positive rate, while the area under the PR-curve is the - * computed using the height of the precision values by the recall. + *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives }, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the AUC. To discretize + * the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and + * precision values. The area under the ROC-curve is therefore computed using the height of the + * recall values by the false positive rate, while the area under the PR-curve is the computed using + * the height of the precision values by the recall. * - *

        This value is ultimately returned as {@code auc}, an idempotent operation that computes - * the area under a discretized curve of precision versus recall values (computed using the + *

        This value is ultimately returned as {@code auc}, an idempotent operation that computes the + * area under a discretized curve of precision versus recall values (computed using the * aforementioned variables). The {@code numThresholds} variable controls the degree of * discretization with larger numbers of thresholds more closely approximating the true AUC. The - * quality of the approximation may vary dramatically depending on {@code numThresholds}. The - * {@code thresholds} parameter can be used to manually specify thresholds which split the - * predictions more evenly. + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The {@code + * thresholds} parameter can be used to manually specify thresholds which split the predictions more + * evenly. * - *

        For best results, {@code predictions} should be distributed approximately uniformly in - * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor - * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code - * majoring} can help quantify the error in the approximation by providing lower or upper - * bound estimate of the AUC. + *

        For best results, {@code predictions} should be distributed approximately uniformly in the + * range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if + * this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code majoring} can + * help quantify the error in the approximation by providing lower or upper bound estimate of the + * AUC. * *

        Usage:
        * @@ -155,8 +155,8 @@ public class AUC extends Metric { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -180,8 +180,8 @@ public AUC(Ops tf, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -206,8 +206,8 @@ public AUC(Ops tf, String name, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code null} + * for labelWeights. 
* * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -233,8 +233,8 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -259,8 +259,8 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -314,8 +314,8 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for - * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and - * {@code null} for labelWeights. + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code + * null} for labelWeights. 
* * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -372,8 +372,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for - * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -400,8 +400,8 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, - * and {@code null} for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -428,8 +428,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, - * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for - * labelWeights. + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. 
Values @@ -453,8 +452,8 @@ public AUC( /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} - * for labelWeights. + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} for + * labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -487,8 +486,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code - * false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code false} + * for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -513,8 +512,8 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, - * {@code false} for multiLabel, and {@code null} for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} @@ -560,16 +559,16 @@ public AUC( * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. This method - * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) - * below and a (1 + {@link #EPSILON}) above. + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) below and + * a (1 + {@link #EPSILON}) above. 
* @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to {@code false} for multi-class data. * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code - * multiLabel} is true, the weights are applied to the individual label AUCs when they - * are averaged to produce the multi-label AUC. When it's false, they are used to weight the + * multiLabel} is true, the weights are applied to the individual label AUCs when they are + * averaged to produce the multi-label AUC. When it's false, they are used to weight the * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -684,8 +683,8 @@ private Map> build(Shape shape) { } // Create metric variables - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(variableShape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(variableShape), type); if (truePositives == null) { truePositives = tf.withName(getTruePositivesName()).variable(zero); initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); @@ -715,8 +714,8 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to {@code }. 
If {@link #multiLabel} or if {@link #labelWeights} {@code != null - * }, then Cx must be a single dimension. + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null }, + * then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to * {@code }. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 516d6c91ba6..b8ec681cbfc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -29,12 +29,10 @@ * Metric that calculates how often predictions equals labels. * *

        This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

        If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

        If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 0e41699e165..a03677efd43 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -26,12 +26,10 @@ * Metric that calculates how often predictions matches binary labels. * *

        This metric creates two local variables, total and count that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as binary accuracy: an idempotent operation that simply divides total by - * count. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as binary accuracy: an idempotent operation that simply divides total by count. * - *

        If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask - * values. + *

        If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index dece2d1cd50..0cd90325e32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,18 +27,17 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

        You can provide {@code logits} of classes as {@code predictions}, since argmax of - * {@code logits} and probabilities are same. + *

        You can provide {@code logits} of classes as {@code predictions}, since argmax of {@code + * logits} and probabilities are same. * - *

        This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the frequency with which {@code predictions} matches {@code labels}. - * This frequency is ultimately returned as categorical accuracy: an idempotent operation that - * simply divides total by count. + *

        This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as categorical accuracy: an idempotent operation that simply divides total by + * count. * - *

        {@code predictions} and {@code labels} should be passed in as vectors of - * probabilities, rather than as labels. If necessary, use {@link - * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand - * {@code labels} as a vector. + *

        {@code predictions} and {@code labels} should be passed in as vectors of probabilities, rather + * than as labels. If necessary, use {@link org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, + * Operand, OneHot.Options...)} to expand {@code labels} as a vector. * *

        If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 58aa51f664c..4a32981aeeb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -29,8 +29,7 @@ * *

        This is the crossentropy metric class to be used when there are multiple label classes (2 or * more). The labels should be given as a one_hot representation. eg., When labels values are {@code - * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * }. + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] }. * * @param The data type for the metric result */ @@ -52,9 +51,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +72,12 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} - * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 - * } for label {@code 1} + * meaning the confidence on label values are relaxed. e.g. 
{@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} - * corresponds to data format {@code channels_last}, and {@code - * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code - * channels_first}. + * corresponds to data format {@code channels_last}, and {@code axis={@link + * Losses#CHANNELS_FIRST}} corresponds to data format {@code channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 3db7fffc2e9..9f957ee6c17 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false negatives. * - *

        If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false negatives. + *

        If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false negatives. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalseNegatives(Ops tf, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 551529b6179..a3d585dea0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false positives. * - *

        If {@code sampleWeights} is given, calculates the sum of the weights of false positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of false positives. + *

        If {@code sampleWeights} is given, calculates the sum of the weights of false positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of false positives. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public FalsePositives(Ops tf, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 22baab3d6cb..04f4deb81cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -93,11 +93,15 @@ private void init() { Shape variableShape = Shape.of(numClasses, numClasses); if (totalConfusionMatrix == null) { - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); totalConfusionMatrix = - getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + getTF() + .withName(totalCMName) + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); initializer = - getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + getTF() + .assign( + totalConfusionMatrix, zeros.call(getTF(), getTF().constant(variableShape), type)); } } @@ -124,8 +128,8 @@ public Assign getInitializer() { * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. 
* @return the Operands that updates totalConfusionMatrix variable - * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, - * and if the predictions size is not equal to the labels size + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank {@code !=} + * labels rank, and if the predictions size is not equal to the labels size */ @Override public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index acf28f5b2cc..8d92b97ec5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -28,13 +28,12 @@ /** * Computes the mean relative error by normalizing with the given values. * - *

        This metric creates two local variables, {@code total} and {@code count} that are - * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is - * ultimately returned as mean relative error: an idempotent operation that simply divides total by - * count. + *

        This metric creates two local variables, {@code total} and {@code count} that are used to + * compute the mean relative error. This is weighted by {@code sampleWeight}, and it is ultimately + * returned as mean relative error: an idempotent operation that simply divides total by count. * - *

        If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} - * of 0 to mask values. + *

        If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} of 0 + * to mask values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index d88d7a4c1b4..583d9b2dde7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -85,8 +85,8 @@ public MeanTensor(Ops tf, String name, long seed, Class type) { private boolean init(Shape shape) { if (!initialized) { this.shape = shape; - Zeros zeros = new Zeros<>(getTF()); - Operand zero = zeros.call(getTF().constant(shape), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(getTF(), getTF().constant(shape), type); if (total == null) { total = getTF().withName(totalName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 3812e799b75..f81b32e8d76 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -36,22 +36,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

        The metric creates two local variables, {@code truePositives} and {@code falsePositives - * } that are used to compute the precision. This value is ultimately returned as precision, - * an idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falsePositives}. + *

        The metric creates two local variables, {@code truePositives} and {@code falsePositives } that + * are used to compute the precision. This value is ultimately returned as precision, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falsePositives}. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

        If {@code topK} is set, the metric calculates precision as how often on average a class - * among the top-k classes with the highest predicted values of a batch entry is correct and can be - * found in the label for that entry. + *

        If {@code topK} is set, the metric calculates precision as how often on average a class among + * the top-k classes with the highest predicted values of a batch entry is correct and can be found + * in the label for that entry. * *

        If {@code classId} is specified, the metric calculates precision by considering only the - * entries in the batch for which {@code classId} is above the {@code thresholds} and/or - * in the top-k highest predictions, and computing the fraction of them for which {@code classId - * } is indeed a correct label. + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or in the + * top-k highest predictions, and computing the fraction of them for which {@code classId } is + * indeed a correct label. * * @param The data type for the metric result */ @@ -103,10 +103,9 @@ public Precision(Ops tf, String name, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -138,10 +137,9 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). 
One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -172,10 +170,9 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -216,10 +213,9 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is - * compared with prediction values to determine the truth value of predictions (i.e., above - * the threshold is true, below is false). One metric value is generated for each threshold - * value. 
+ * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated for each threshold value. * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -280,17 +276,15 @@ public Precision( /** Initializes the variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { this.truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); } if (this.falsePositives == null) { - this.falsePositives = - tf.withName(falsePositivesName) - .variable(zero); + this.falsePositives = tf.withName(falsePositivesName).variable(zero); initializers.add(tf.assign(falsePositives, zero)); } } @@ -340,11 +334,12 @@ public List updateStateList( public Operand result() { Ops tf = getTF(); Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - return thresholds.length == 1 - ? tf.reshape(tf.slice( - result, - tf.expandDims(tf.constant(0), tf.constant(0)), - tf.expandDims(tf.constant(1), tf.constant(0))), + return thresholds.length == 1 + ? 
tf.reshape( + tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))), tf.constant(Shape.scalar())) : result; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 5f5f9b47a10..0bb49378f5b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -29,8 +29,8 @@ * falseNegatives that are used to compute the precision at the given recall. The threshold for the * given recall value is computed and used to evaluate the corresponding precision. * - *

        If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

        If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ @@ -115,8 +115,7 @@ public PrecisionAtRecall( public Operand result() { Ops tf = getTF(); - Operand div = - tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand div = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 3886ec050b0..2780add994f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -36,20 +36,20 @@ /** * Computes the recall of the predictions with respect to the labels. * - *

        This metric creates two local variables, {@code truePositives} and {@code falseNegatives - * }, that are used to compute the recall. This value is ultimately returned as recall, an - * idempotent operation that simply divides {@code truePositives} by the sum of {@code - * truePositives} and {@code falseNegatives}. + *

        This metric creates two local variables, {@code truePositives} and {@code falseNegatives}, + * that are used to compute the recall. This value is ultimately returned as recall, an idempotent + * operation that simply divides {@code truePositives} by the sum of {@code truePositives} and + * {@code falseNegatives}. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of - * 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. * - *

        If {@code topK} is set, the metric calculates recall as how often on average a class - * among the labels of a batch entry is in the top-k predictions. + *

        If {@code topK} is set, the metric calculates recall as how often on average a class among the + * labels of a batch entry is in the top-k predictions. * - *

        If {@code classId} is specified, the metric calculates recall by considering only the - * entries in the batch for which {@code classId} is in the label, and computing the fraction - * of them for which {@code classId} is above the threshold and/or in the top-k predictions. + *

        If {@code classId} is specified, the metric calculates recall by considering only the entries + * in the batch for which {@code classId} is in the label, and computing the fraction of them for + * which {@code classId} is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ @@ -305,8 +305,8 @@ public Recall( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); - Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { truePositives = tf.withName(truePositivesName).variable(zero); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index a3fc2f77b7f..e54def48fce 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -34,8 +34,8 @@ * falseNegatives that are used to compute the recall at the given precision. The threshold for the * given precision value is computed and used to evaluate the corresponding recall. * - *

        If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of - * 0 to mask values. + *

        If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of 0 to mask + * values. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 3886428425b..0d140eb96b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,8 +27,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between {@code labels} and {@code predictions} - * . + * Computes root mean squared error metric between {@code labels} and {@code predictions}. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 29c0504b823..23a529ae1bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -25,19 +25,19 @@ /** * Computes best sensitivity where sensitivity is >= specified value. * - *

        {@code Sensitivity} measures the proportion of actual positives that are correctly - * identified as such {@code (tp / (tp + fn))}. + *

        {@code Sensitivity} measures the proportion of actual positives that are correctly identified + * as such {@code (tp / (tp + fn))}. * - *

        {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

        {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * sensitivity at the given specificity. The threshold for the given specificity value is computed - * and used to evaluate the corresponding sensitivity. + *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the sensitivity at the + * given specificity. The threshold for the given specificity value is computed and used to evaluate + * the corresponding sensitivity. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 5294f798044..1d017ddf8fb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -35,9 +35,9 @@ * probabilities are same. * *

        This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which {@code predictions} matches {@code labels}. This frequency is - * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides - * `total` by `count`. + * frequency with which {@code predictions} matches {@code labels}. This frequency is ultimately + * returned as `sparse categorical accuracy`: an idempotent operation that simply divides `total` by + * `count`. * *

        If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 2cb7e54eba0..95d46c8fd06 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -24,19 +24,19 @@ /** * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} - * measures the proportion of actual positives that are correctly identified as such {@code - * (tp / (tp + fn))}. + * measures the proportion of actual positives that are correctly identified as such {@code (tp / + * (tp + fn))}. * - *

        {@code Specificity} measures the proportion of actual negatives that are correctly - * identified as such {@code (tn / (tn + fp))}. + *

        {@code Specificity} measures the proportion of actual negatives that are correctly identified + * as such {@code (tn / (tn + fp))}. * - *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives - * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the - * specificity at the given sensitivity. The threshold for the given sensitivity value is computed - * and used to evaluate the corresponding specificity. + *

        This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the specificity at the + * given sensitivity. The threshold for the given sensitivity value is computed and used to evaluate + * the corresponding specificity. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @see Additional information * about specificity and sensitivity diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index 637ca6cdd05..bcb1d7b9a36 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -21,11 +21,11 @@ /** * Computes the (weighted) sum of the given values. * - *

        For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the - * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} + *

        For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the weights were + * specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} * - *

        This metric creates one variable, {@code total}, that is used to compute the sum of - * values. This is ultimately returned as sum. + *

        This metric creates one variable, {@code total}, that is used to compute the sum of values. + * This is ultimately returned as sum. * *

        If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index 0146552433f..b6e50c3295a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -34,8 +34,8 @@ public class TopKCategoricalAccuracy extends MeanMetricWrappe private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of - * top elements to look at for computing accuracy. + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of top + * elements to look at for computing accuracy. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index 5c65f8c469f..fd6b95df6d2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true negatives. * - *

        If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true negatives. + *

        If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true negatives. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TrueNegatives(Ops tf, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index f0dd8c42de5..90fe9142014 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true positives. * - *

        If {@code sampleWeights} is given, calculates the sum of the weights of true positives. - * This metric creates one local variable, {@code accumulator} that is used to keep track of - * the number of true positives. + *

        If {@code sampleWeights} is given, calculates the sum of the weights of true positives. This + * metric creates one local variable, {@code accumulator} that is used to keep track of the number + * of true positives. * - *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code - * sampleWeights} of 0 to mask values. + *

        If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 + * to mask values. * * @param The data type for the metric result */ @@ -50,10 +50,10 @@ public TruePositives(Ops tf, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -66,10 +66,10 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). 
One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -96,10 +96,10 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables @@ -113,10 +113,10 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared - * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is {@code true}, below is {@code false}). One metric value is generated - * for each threshold value + * @param thresholds threshold values in the range {@code [0, 1]}. 
A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each threshold + * value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 88597cf85ec..b031d80d0ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -67,10 +67,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -91,10 +90,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with - * prediction values to determine the truth value of predictions (i.e., above the threshold is - * {@code true}, below is {@code false}). One metric value is generated for each - * threshold value. + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with prediction + * values to determine the truth value of predictions (i.e., above the threshold is {@code + * true}, below is {@code false}). One metric value is generated for each threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables @@ -118,12 +116,13 @@ public ConfusionMatrixConditionCount( private void init() { Shape variableShape = Shape.of(this.thresholds.length); - Zeros zeros = new Zeros<>(getTF()); + Zeros zeros = new Zeros<>(); accumulator = getTF() .withName(getAccumulatorName()) - .variable(zeros.call(getTF().constant(variableShape), type)); - initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); + initializer = + getTF().assign(accumulator, zeros.call(getTF(), getTF().constant(variableShape), type)); } /** @@ -189,7 +188,10 @@ public float[] getThresholds() { return this.thresholds; } - /** @return the accumulatorName */ + /** + * Gets the accumulatorName + * @return the accumulatorName + */ public String getAccumulatorName() { return accumulatorName; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index f89047e457d..76c21aebefc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * Interface for Metrics that wrap Loss functions. + * Interface for Metrics that wrap AbstractLoss functions. * * @param The data type of the predictions. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 37bdd5849ae..ec103197709 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,9 +29,9 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

        The loss function calculates the loss between the {@code labels} and {@code predictions - * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the - * loss over many iterations or epochs + *

        The loss function calculates the loss between the {@code labels} and {@code predictions } then + * passes this loss to the {@link Mean} metric to calculate the weighted mean of the loss over many + * iterations or epochs * * @param The data type for the metric result */ @@ -63,7 +63,7 @@ public LossMetric getLoss() { } /** - * Sets the Loss function for this wrapper. + * Sets the AbstractLoss function for this wrapper. * * @param loss the loss function. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 40336233d21..51b8836ec83 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -59,8 +59,7 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values - * } + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values } * *

        In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -69,8 +68,8 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} - * can be broadcast to {@code values} + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} can be + * broadcast to {@code values} * @param the type of Operand * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an * incorrect shape that prohibit broadcasting to {@code values} @@ -114,10 +113,7 @@ public static Op assertBroadcastable( throw new NotBroadcastableException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCAST_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -307,24 +303,24 @@ public static List assertShapes( *

        For estimation of these metrics over a stream of data, the function creates an `update_op` * operation that updates the given variables. * - *

        {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code - * sampleWeight} is then broadcast to the shape of {@code predictions}. + *

        {@code labels}, {@code predictions}, and {@code sampleWeight} tensors are aligned by {@link + * LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code sampleWeight} is then + * broadcast to the shape of {@code predictions}. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then - * the variable shapes are (T). + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then the + * variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when - * topK is set + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when topK + * is set * @param topK optional, indicates that only the top k predictions should be considered. Applied * before possibly slicing by {@code classIndex}. * @param classIndex optional, limits the prediction and labels to the specified class. This is an @@ -338,14 +334,14 @@ public static List assertShapes( * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. 
when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex - * } is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. - * Must be null if {@code multiLabel}. + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex } + * is provided, then the final dimension of Dx is 1. These weights will be broadcast across + * the 0th dimension (the examples dimension) of {@code predictions}. May be null. Must be + * null if {@code multiLabel}. * @param the data type for the variables - * @throws IllegalArgumentException If {@code predictions} and {@code labels} have - * mismatched shapes, or if {@code sampleWeight} is not null and its shape - * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. + * @throws IllegalArgumentException If {@code predictions} and {@code labels} have mismatched + * shapes, or if {@code sampleWeight} is not null and its shape doesn't match {@code + * predictions}, or if {@code multiLabel && labelWeights != null}.. * @return an op to update the given confusion matrix variables. 
*/ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -439,11 +435,13 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), if (classIndex != null) { // Slice to new shapes (N, Dx) - tLabels = tf.squeeze(tf.gather(tLabels, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tLabels = + tf.squeeze( + tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); - tPredictions = tf.squeeze(tf.gather(tPredictions, - tf.constant(new int[] {classIndex}), tf.constant(-1)), + tPredictions = + tf.squeeze( + tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(-1)), Squeeze.axis(Collections.singletonList(1L))); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); @@ -693,8 +691,7 @@ private static Operand filterTopK(Ops tf, Operand x, i // alias for mean /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -706,8 +703,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -725,10 +722,9 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the type of the operand * @return the mean of elements of {@code x}. */ @@ -742,10 +738,9 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @param the data type of the Operand * @return the mean of elements of {@code x}. */ @@ -783,12 +778,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( *

        For example: * *

        {@code
        -   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        -   *          [[0 0 0 0 0]
        -   *           [0 0 1 0 0]
        -   *           [0 0 1 0 0]
        -   *           [0 0 0 0 0]
        -   *           [0 0 0 0 1]]
        +   * confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        +   *      [[0 0 0 0 0]
        +   *       [0 0 1 0 0]
        +   *       [0 0 1 0 0]
        +   *       [0 0 0 0 0]
        +   *       [0 0 0 0 1]]
            * }
        * * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 @@ -802,12 +797,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands - * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} - * representing the confusion matrix, where {@code n} is the number of possible labels in - * the classification task. - * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do - * not have compatible shapes, or if {@code weights} is not{@code null} and its - * shape is not compatible with {@code predictions}. + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} representing the + * confusion matrix, where {@code n} is the number of possible labels in the classification + * task. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do not have + * compatible shapes, or if {@code weights} is not{@code null} and its shape is not compatible + * with {@code predictions}. */ // TODO should this be moved to FramnworkOps under math. 
public static Operand confusionMatrix( @@ -883,8 +878,7 @@ public static Operand confusionMatrix( } /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -895,8 +889,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is {@code + * false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -913,10 +907,9 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { @@ -929,10 +922,9 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained - * with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is {@code + * false}, the rank of the tensor is reduced by 1 for each entry in {@code axes }. If {@code + * keepdims} is {@code true}, the reduced dimensions are retained with length 1. * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 60a6c1ea3df..e47ea4ea8e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -87,9 +87,9 @@ protected SensitivitySpecificityBase( /** Initializes the Variables */ private void init() { Ops tf = getTF(); - Zeros zeros = new Zeros<>(tf); + Zeros zeros = new Zeros<>(); Shape varShape = Shape.of(numThresholds); - Operand zero = zeros.call(tf.constant(varShape), type); + Operand zero = zeros.call(tf, tf.constant(varShape), type); if (this.getTruePositives() == null) { @@ -228,8 +228,6 @@ public int getNumThresholds() { return numThresholds; } - - /** * Gets the thresholds * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 68157632557..0553b1edac7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ 
-26,8 +26,8 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with - * {@code aMinusB} set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with {@code + * aMinusB} set to true. * *

        All but the last dimension of {@code a} and {@code b} must match * @@ -35,8 +35,8 @@ public class SetsOps { * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference(Ops tf, Operand a, Operand b) { @@ -53,8 +53,8 @@ public static Operand difference(Ops tf, Operand a, Op * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand difference( @@ -69,8 +69,8 @@ public static Operand difference( * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. 
*/ public static Operand union(Ops tf, Operand a, Operand b) { @@ -84,8 +84,8 @@ public static Operand union(Ops tf, Operand a, Operand * @param a The first operand representing set {@code a} * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the * same. Elements along the last dimension contain the results of the set * operation. */ public static Operand intersection(Ops tf, Operand a, Operand b) { @@ -100,8 +100,8 @@ public static Operand intersection(Ops tf, Operand a, * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the same. Elements along the last dimension contain the results of the set + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last + * dimension the same. Elements along the last dimension contain the results of the set * operation. */ public static Operand setOperation( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java index d28185ae041..7c3fda07ea9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -21,35 +21,72 @@ import java.util.Arrays; import java.util.List; +/** + * A class that represents a Symbolic shape. + * + *

        A Symbolic shape uses symbols to identify the relationship of the shape of an operand to + * underlying values that are not known until compute time. For example, "N" represents the number of + * examples, while "L" represents the number of labels. When the values later become known, the + * shape of the operand must conform to these symbolic values. + * + * @param The data type for the Operand. + */ public class SymbolicShape { private Operand operand; private List symbols = new ArrayList<>(); + /** + * Creates a SymbolicShape + * + * @param operand the Operand that needs to conform to the shape + * @param symbols the symbolic value for each dimension of the shape. + */ public SymbolicShape(Operand operand, String... symbols) { this.operand = operand; this.symbols.addAll(Arrays.asList(symbols)); } - /** @return the operand */ + /** + * Gets the operand + * + * @return the operand + */ public Operand getOperand() { return operand; } - /** @param operand the operand to set */ + /** + * Sets the operand + * + * @param operand the operand to set + */ public void setOperand(Operand operand) { this.operand = operand; } - /** @return the symbols */ + /** + * Gets the symbols associated with each dimension of the shape + * + * @return the symbols associated with each dimension of the shape + */ public List getSymbols() { return symbols; } - /** @param symbols the symbols to set */ + /** + * Sets the symbols associated with each dimension of the shape + * + * @param symbols the symbols associated with each dimension of the shape + */ public void setSymbols(List symbols) { this.symbols = symbols; } + /** + * Gets the rank associated with this Symbolic Shape + * + * @return the rank associated with this Symbolic Shape + */ public int rank() { return this.symbols.size(); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 6583465da2e..18b11700380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -32,8 +32,8 @@ /** * Weight broadcasting operations. * - *

        In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we support limited weight broadcasting. This file includes - * operations for those broadcasting rules. + *

        In {@link org.tensorflow.framework.losses} and {@link org.tensorflow.framework.metrics}, we + * support limited weight broadcasting. This file includes operations for those broadcasting rules. */ public class WeightsBroadcastOps { @@ -46,10 +46,11 @@ public class WeightsBroadcastOps { * @param tf the TensorFlow Ops * @param weights the weights Operand * @param values Operand of values to which weights are applied. - * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has incorrect shape. {@link NoOp} if - * static checks determine {@code weights} has correct shape. + * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has + * incorrect shape. {@link NoOp} if static checks determine {@code weights} has correct shape. * @param the type of weights and values - * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape. + * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect + * shape. */ public static Op assertBroadcastable( Ops tf, Operand weights, Operand values) { @@ -81,14 +82,12 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (weightsShapeStatic.size(i) != 1 && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + if (weightsShapeStatic.size(i) != 1 + && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %s. 
values.shape=%s weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, - i, - valuesShapeStatic, - weightsShapeStatic)); + ASSERT_BROADCASTABLE_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -105,12 +104,12 @@ public static Op assertBroadcastable( tf.constant("values.shape="), valuesShape, tf.constant("isScalar="), - isScalar); + isScalar); Operand isValidShape = tf.select( - isScalar, - isScalar, + isScalar, + isScalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); return tf.assertThat(isValidShape, data); @@ -140,7 +139,8 @@ private static Operand hasValidNonscalarShape( } /** - * Checks that each dimension of the two shapes are the same size, or that the weight dimension size is 1. + * Checks that each dimension of the two shapes are the same size, or that the weight dimension + * size is 1. * * @param tf the TensorFlow Ops * @param weightsShape the shape of the weights @@ -152,7 +152,8 @@ private static Operand hasValidDims( tf = tf.withSubScope("hasInvalidDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); - Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); @@ -164,8 +165,7 @@ private static Operand hasValidDims( * Broadcast {@code weights} to the same shape as {@code values}. * *

        This returns a version of {@code weights} following the same broadcast rules as {@code - * mul(weights, - * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * mul(weights, values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} * When computing a weighted average, use this function to broadcast {@code weights} before * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java new file mode 100644 index 00000000000..25535292db3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/AbstractRegularizer.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.framework.losses.impl.AbstractLoss; + +/** + * Base class for Regularizers + * + *

        Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + */ +public abstract class AbstractRegularizer implements Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final String name; + + /** Creates a AbstractRegularizer, using {@link Class#getSimpleName()} for the name */ + protected AbstractRegularizer() { + this(null); + } + /** + * Creates a AbstractRegularizer + * + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. + */ + protected AbstractRegularizer(String name) { + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this AbstractRegularizer as a AbstractLoss This is a convenience to use regularize a + * loss. Only sampleWeights are applied to the regularizer. + * + * @return this AbstractRegularizer as a AbstractLoss + */ + public AbstractLoss asLoss() { + return new RegularizerLoss(this); + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 7c8c2a1360a..4b7aa1af620 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -14,8 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) * Regression, regularization penalty. 
@@ -24,24 +22,43 @@ */ public class L1 extends L1L2 { + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L1() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L1 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L1(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L1 regularization penalty and a name based on the class + * name. + * + * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L1(float l1) { + this(null, l1); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l1 the L1 regularization penalty * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf, float l1) { - super(tf, l1, 0f); + public L1(String name, float l1) { + super(name, l1, 0f); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 29e411f9897..6dfaf3f0d47 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -19,6 +19,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A regularizer that applies both L1 and L2 regularization penalties. * @@ -29,33 +31,39 @@ *

        The L2 regularization penalty is computed as * *

        loss = l2 * reduceSum(square(x))
        - * */ -public class L1L2 extends Regularizer { +public class L1L2 extends AbstractRegularizer { private final float l1; private final float l2; + /** Creates an L1L2 regularizer with both the l1 and l2 penalties set to {@link #DEFAULT_REGULARIZATION_PENALTY} */ + public L1L2() { + this(DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + /** - * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty + * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor, if null it is set to 0. + * @param l2 L2 regularization factor, if null it is set to 0. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * or {@link Float#isInfinite} */ - public L1L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + public L1L2(float l1, float l2) { + this(null, l1, l2); } /** * Creates an L1L2 regularizer * - * @param tf the TensorFlow Ops * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. 
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, float l1, float l2) { - super(tf); + public L1L2(String name, float l1, float l2) { + super(name); if (Float.isNaN(l1) || Float.isInfinite(l1)) { throw new IllegalArgumentException( String.format( @@ -73,25 +81,23 @@ public L1L2(Ops tf, float l1, float l2) { this.l2 = l2; } - /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - Ops tf = getTF(); + public Operand call(Ops tf, Operand input) { if (this.getL1() == 0f && this.getL2() == 0f) { - return tf.dtypes.cast(tf.constant(0), input.type()); + return cast(tf, tf.constant(0), input.type()); } - Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + Operand regularization = cast(tf, tf.constant(0), input.type()); if (this.getL1() != 0.f) { - Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand l1Op = cast(tf, tf.constant(this.getL1()), input.type()); Operand abs = tf.math.abs(input); Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); } if (this.getL2() != 0.f) { - Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand l2Op = cast(tf, tf.constant(this.getL2()), input.type()); Operand sqr = tf.math.square(input); Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 7b8f5b28a70..9092b80b08f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -14,8 +14,6 @@ 
=======================================================================*/ package org.tensorflow.framework.regularizers; -import org.tensorflow.op.Ops; - /** * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * @@ -23,24 +21,43 @@ */ public class L2 extends L1L2 { + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} and a name based on the class name. + */ + public L2() { + this(null, DEFAULT_REGULARIZATION_PENALTY); + } + /** * Create a regularizer that applies an L2 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer + */ + public L2(String name) { + this(name, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L2 regularization penalty and a name based on the class + * name. + * + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ - public L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY); + public L2(float l2) { + this(null, l2); } /** * Create a regularizer that applies an L1 regularization penalty * - * @param tf the TensorFlow Ops + * @param name the name for this AbstractRegularizer * @param l2 the L2 regularization penalty * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. 
*/ - public L2(Ops tf, float l2) { - super(tf, 0f, l2); + public L2(String name, float l2) { + super(name, 0f, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 5d9ff0e3e10..085f06e115c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,77 +15,18 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Base class for Regularizers - * - *

        Regularizers allow you to apply penalties on layer parameters or layer activity during - * optimization. These penalties are summed into the loss function that the network optimizes. - */ -public abstract class Regularizer { - - public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; - - private final Ops tf; - private final String name; - - /** - * Creates a Regularizer, using {@link Class#getSimpleName()} for the name - * - * @param tf the TensorFlow ops. - */ - protected Regularizer(Ops tf) { - this(tf, null); - } - /** - * Creates a Regularizer - * - * @param tf the TensorFlow ops. - * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the - * name. - */ - protected Regularizer(Ops tf, String name) { - this.tf = tf; - this.name = name == null ? this.getClass().getSimpleName() : name; - } - - /** - * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only - * sampleWeights are applied to the regularizer. - * - * @return this Regularizer as a Loss - */ - public Loss asLoss() { - return new RegularizerLoss(this.tf, this); - } +public interface Regularizer { /** * Computes a regularization penalty from an input. 
* + * @param tf the TensorFlow Ops * @param input the weighted input * @return the result of computing the regularization penalty * @param the data type of the input and result */ - public abstract Operand call(Operand input); - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the name for this regularizer - * - * @return the name for this regularizer - */ - public String getName() { - return name; - } + Operand call(Ops tf, Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index 582cd038f8f..11c7ee492e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -15,50 +15,49 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A Regularizer call wrapped as a Loss instance + * A AbstractRegularizer call wrapped as a AbstractLoss instance * *

        This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. */ -class RegularizerLoss extends Loss { +class RegularizerLoss extends AbstractLoss { - private final Regularizer regularizer; + private final AbstractRegularizer regularizer; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Loss#REDUCTION_DEFAULT} + * Creates an AbstractLoss using {@link Class#getSimpleName()} as the name and an AbstractLoss + * Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, Regularizer regularizer) { - this(tf, null, regularizer); + public RegularizerLoss(AbstractRegularizer regularizer) { + this(null, regularizer); } /** - * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * Creates an AbstractLoss using an AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} * - * @param tf the TensorFlow Ops - * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param name the name of this AbstractLoss, if null the name will be {@link + * Class#getSimpleName()}. 
* @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { - super(tf, name); + public RegularizerLoss(String name, AbstractRegularizer regularizer) { + super(name); this.regularizer = regularizer; } /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } - return regularizer.call(sampleWeights); + return regularizer.call(tf, sampleWeights); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java index 914b94dfada..9f3fa75e95d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -14,36 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - -/** @author Jim Clarke */ public class ELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of ELU call method */ @Test public void testCallFloat() { @@ -52,8 +33,8 @@ public void testCallFloat() 
{ for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -66,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -80,8 +61,8 @@ public void testAlpha() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf, 2.0f); - Operand result = instance.call(tf.constant(input)); + ELU instance = new ELU<>(2.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java index 1157c582168..f82c19987d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static 
org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class ExponentialTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ExponentialTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Exponential call method. */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -78,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); - Operand result = instance.call(tf.constant(input)); + Exponential instance = new Exponential<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java index 35f57c47f66..0e32201c3e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import 
org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class HardSigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public HardSigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of HardSigmoid call method. */ @Test public void testCallFloat() { @@ -51,8 +33,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -65,8 +47,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + HardSigmoid instance = new HardSigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java index 7974035c680..817940688e8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class LinearTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LinearTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Linear call method. */ @Test public void testCallInt() { @@ -48,8 +34,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -62,8 +48,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -76,8 +62,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); - Operand result = instance.call(tf.constant(input)); + Linear instance = new Linear<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index a0aa2c4b453..94f803d6b1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -14,30 +14,20 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; /** @author Jim Clarke */ public class ReLUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ReLUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of ReLU call method */ @Test public void testCallFloat() { @@ -46,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -60,8 +50,8 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = 
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -74,8 +64,8 @@ public void testCallLong() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,9 +78,9 @@ public void testCallFloat16() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU<>(); Operand result = - instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); + instance.call(tf, tf.dtypes.cast(tf.constant(input), TFloat16.class)); session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); } } @@ -103,8 +93,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -112,12 +102,12 @@ public void testCallDouble() { @Test public void testAlpha() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-5. 
, -2.5, 0., 5., 10.}; + double[] expected = {-5., -2.5, 0., 5., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -129,8 +119,8 @@ public void testMaxValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -138,12 +128,12 @@ public void testMaxValue() { @Test public void testThreshold() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-0., -0., 0., 0., 10.}; + double[] expected = {-0., -0., 0., 0., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); - Operand result = instance.call(tf.constant(input)); + ReLU instance = new ReLU<>(ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index 8bad6f1f066..ef4644df18e 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SELUTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of SELU call method */ @Test public void testCallFloat() { @@ -53,8 +35,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +53,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); - Operand result = instance.call(tf.constant(input)); + SELU instance = new SELU<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java index 9dca622c3ec..0c59eeaba6e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -14,34 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SigmoidTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Sigmoid call method */ @Test public void testCallFloat() { @@ -59,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -77,8 +60,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); - Operand result = instance.call(tf.constant(input)); + Sigmoid instance = new Sigmoid<>(); + Operand result = instance.call(tf, 
tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java index 05ec3a4f716..aeb971905a2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -14,35 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SoftmaxTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftmaxTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of Softmax method, of class Activations. 
*/ @Test public void testSoftmaxOpsOperandFloat() { @@ -54,8 +37,8 @@ public void testSoftmaxOpsOperandFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -71,8 +54,8 @@ public void testSoftmaxOpsOperandDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -88,8 +71,8 @@ public void testSoftmaxOpsOperandDoubleNegative() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -99,14 +82,14 @@ public void testSoftmaxOpsOperandDoubleNegative() { public void testSoftmax1D() { double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; double[] expected = { - 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, - 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 + 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, + 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = 
instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } @@ -116,14 +99,14 @@ public void testSoftmax1D() { public void testSoftmax3D() { double[][][] input = {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}; double[][][] expected = { - {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, - {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} + {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, + {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); - Operand result = instance.call(tf.constant(input)); + Softmax instance = new Softmax<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(tf.constant(expected), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java index a17f2650d62..e896807d9f7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftplusTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftplusTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static 
void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softplus call method */ @Test public void testCallFloat() { @@ -50,8 +36,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -68,8 +54,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); - Operand result = instance.call(tf.constant(input)); + Softplus instance = new Softplus<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java index 43591ab4761..2f9a17caf59 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,20 +26,6 @@ public class SoftsignTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SoftsignTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - 
public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Softsign call method */ @Test public void testCallFloat() { @@ -48,8 +34,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); - Operand result = instance.call(tf.constant(input)); + Softsign instance = new Softsign<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java index 7576789320b..8dabfaf379a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -14,35 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SwishTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public SwishTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - - /** Test of Swish call method */ @Test public void testCallFloat() { @@ -60,8 +42,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -83,8 +65,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); - Operand result = instance.call(tf.constant(input)); + Swish instance = new Swish<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java index 5162e141c44..3988ec55bb3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -25,20 +25,6 @@ public class TanhTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, 
TestSession.Mode.GRAPH}; - public TanhTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of Tanh call method. */ @Test public void testCallFloat() { @@ -52,8 +38,8 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } @@ -71,8 +57,8 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); - Operand result = instance.call(tf.constant(input)); + Tanh instance = new Tanh<>(); + Operand result = instance.call(tf, tf.constant(input)); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index 1f80388e88f..259d6a963b5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -35,8 +35,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MaxNorm instance = new MaxNorm(tf, testValues[i.get()]); - Operand result = instance.call(weights); + MaxNorm instance = new MaxNorm(testValues[i.get()]); + Operand result = instance.call(tf, weights); session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } } @@ -47,13 +47,13 @@ public void testCall1() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm(tf, 2.0); + MaxNorm instance = new MaxNorm(2.0); Operand weights = tf.constant( new float[][] { {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, }); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float[] expected = { 0, 1, 2, 1.1547005f, 0, 0, 0, 1.1547005f, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8c2c3a54ff9..8b4c4007096 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -39,8 +39,8 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); - Operand result = instance.call(weights); + MinMaxNorm instance = new MinMaxNorm(testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(tf, weights); if (tfMode == TestSession.Mode.EAGER) evaluate(session, result.asTensor(), testValues[i.get()]); else diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java index 6a6fdc13536..1a24c188860 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -17,8 +17,8 @@ public void testTFloat32() { Ops tf = session.getTF(); float[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance 
= new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); float[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } @@ -31,8 +31,8 @@ public void testTFloat64() { Ops tf = session.getTF(); final double[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; Operand weights = tf.constant(array); - NonNeg instance = new NonNeg(tf); - Operand result = instance.call(weights); + NonNeg instance = new NonNeg(); + Operand result = instance.call(tf, weights); double[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java index 6437ebcd760..9c784b7f31e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -28,8 +28,8 @@ public void testTFloat32() { }; Operand weights = tf.constant(array); - UnitNorm instance = new UnitNorm(tf, 1); - Operand result = instance.call(weights); + UnitNorm instance = new UnitNorm(1); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } @@ -50,9 +50,9 @@ public void testCallTFloat64() { {{0.72920675, 0.40984813, 0.55712338}, {0.68429305, 0.91215323, 0.83042956}}, {{0.97694125, 0.99972269, 0.13576831}, {0.21350717, 0.02353181, 0.99074035}} }; - UnitNorm instance = new UnitNorm(tf, 1); + UnitNorm instance = new UnitNorm(1); Operand weights = tf.constant(array); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 4e81e0620e6..9291e5f83ef 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class ConstantTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ConstantTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallUInt() { @@ -51,8 +43,9 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -67,8 +60,9 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Constant instance = new Constant<>(0xf); + + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -83,8 +77,9 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xffL); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Constant instance = new Constant<>(0xffL); + + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -97,8 +92,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 12.F); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Constant instance = new Constant<>(12.F); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -112,8 +108,9 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand 
operand = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -129,8 +126,9 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.class); + Constant instance = new Constant<>(22); + + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -145,8 +143,9 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Boolean[] expected = {true, true, true, true}; - Constant instance = new Constant<>(tf, true); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Constant instance = new Constant<>(true); + + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -158,9 +157,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Constant instance = new Constant<>(11.); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index e9769806928..166011c3b64 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -14,7 +14,7 @@ 
=======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class GlorotTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public GlorotTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Glorot. */ @Test public void testCallNormalFloat() { @@ -51,9 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,8 +54,9 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,8 +69,9 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new 
Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -97,8 +85,9 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -109,9 +98,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -122,9 +112,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), 
TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -135,10 +126,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = - new Glorot<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Glorot instance = new Glorot<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 8953fa3005e..7b183358f85 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class HeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; int counter; - public HeTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class He. 
*/ @Test public void testCallNormalFloat() { @@ -51,8 +37,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +53,9 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +68,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +84,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 
+97,10 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.TRUNCATED_NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +111,10 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.UNIFORM, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +125,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + He instance = new He<>(Distribution.NORMAL, SEED); + + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index 6eee5473937..3f5c6cdb363 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -14,37 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Identity initializer */ public class IdentityTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public IdentityTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Constant. 
*/ @Test public void testCallFloat() { @@ -64,8 +46,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -90,8 +72,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -103,9 +85,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Identity instance = new Identity<>(tf, 2.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Identity instance = new Identity<>(2.); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 336850a5549..8858bac13dd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package 
org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.VarianceScaling.Distribution; import org.tensorflow.framework.utils.TestSession; @@ -29,20 +29,6 @@ public class LeCunTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public LeCunTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class LeCun. */ @Test public void testCallNormalFloat() { @@ -51,8 +37,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +52,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +66,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, 
SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +81,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +93,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +106,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, 
tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +119,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + LeCun instance = new LeCun<>(Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index 053ba5dd7ff..4872ce7ad8e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -14,12 +14,18 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; @@ -29,20 +35,6 @@ public class OnesTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OnesTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Ones. */ @Test public void testCallUInt() { @@ -51,8 +43,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -65,8 +57,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -79,8 +71,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -93,8 +85,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); 
session.evaluate(expected, operand); } } @@ -108,8 +100,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -125,8 +117,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - instance.call(tf.constant(shape), TString.class); + Ones instance = new Ones<>(); + instance.call(tf, tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -140,8 +132,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Ones instance = new Ones<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -153,9 +145,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Ones instance = new Ones<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.ones(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + 
Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index 22b89d9177c..c933e669dfd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -14,17 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Orthogonal initializer */ public class OrthogonalTest { @@ -33,20 +29,6 @@ public class OrthogonalTest { private static final double GAIN_VALUE = 1.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public OrthogonalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Orthogonal. 
*/ @Test public void testCallFloat() { @@ -156,8 +138,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -271,8 +253,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -284,9 +266,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Orthogonal instance = new Orthogonal<>(GAIN_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 3b2b3bdb243..dada058af42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class RandomNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomNormal. */ @Test public void testCalltestSoftmaxFloat() { @@ -52,9 +38,8 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } 
} @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomNormal instance = new RandomNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index 23e26083a9b..1a1b3f755b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -31,20 +31,6 @@ public class RandomUniformTest { private static final double MAX_VALUE = 10.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public RandomUniformTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomUniform. 
*/ @Test public void testCallInt() { @@ -53,9 +39,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -84,9 +68,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -98,10 +81,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + RandomUniform instance = new RandomUniform<>(MIN_VALUE, MAX_VALUE, SEED); + 
Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 96bf915e199..6ea19fde349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -30,20 +30,6 @@ public class TruncatedNormalTest { private static final double STDDEV_VALUE = 3.0; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public TruncatedNormalTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class TruncatedNormal. 
*/ @Test public void testCallFloat() { @@ -52,9 +38,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +53,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +66,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + TruncatedNormal instance = new TruncatedNormal<>(MEAN_VALUE, STDDEV_VALUE, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 159affb07e2..56aa95ecf73 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -28,20 +28,6 @@ public class VarianceScalingTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public VarianceScalingTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class VarianceScaling. */ @Test public void testCallFloat1FanInTruncatedNormal() { @@ -52,12 +38,11 @@ public void testCallFloat1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -73,12 +58,11 @@ public void testCallDouble1FanInTruncatedNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -93,12 +77,8 @@ public void testCallFloat1FanInNormal() { Shape shape = 
Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -114,12 +94,8 @@ public void testCalltestSoftmaxDouble1FanInNormal() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -134,8 +110,8 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -151,8 +127,8 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() { Shape shape = Shape.of(2, 2); VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -166,9 +142,9 @@ public void 
testReproducible1() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -182,13 +158,9 @@ public void testReproducible2() { VarianceScaling instance = new VarianceScaling<>( - tf, - 1.0, - VarianceScaling.Mode.FAN_IN, - VarianceScaling.Distribution.NORMAL, - SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -202,13 +174,12 @@ public void testReproducible3() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_OUT, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -222,9 +193,9 @@ public void testReproducible4() { VarianceScaling instance = new VarianceScaling<>( - tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - 
Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java index 21bad6ff360..772baee1b61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java @@ -14,32 +14,24 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; /** Test the Zeros initializer */ public class ZerosTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - public ZerosTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class Zeros. 
*/ @Test public void testCallUInt() { @@ -48,8 +40,8 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -62,8 +54,8 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -76,8 +68,8 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -90,8 +82,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -105,8 +97,8 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand 
operand = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -119,8 +111,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TString.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TString.class); session.evaluateString(operand, String::isEmpty); } } @@ -134,8 +126,8 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.class); + Zeros instance = new Zeros<>(); + Operand operand = instance.call(tf, tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -147,9 +139,23 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Zeros instance = new Zeros<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); + Zeros instance = new Zeros<>(); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); + session.evaluate(operand1, operand2); + } + } + + @Test + public void testFunctional() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + + Initializer instance = (ltf, dims, type) -> ltf.zeros(dims, type); + Operand operand1 = instance.call(tf, tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf, tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index d2128b80839..0b662414e8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -32,11 +32,12 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -48,9 +49,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new BinaryCrossentropy(tf, true); + instance = new BinaryCrossentropy(true); - loss = instance.call(yTrue, logits); + loss = instance.call(tf, yTrue, logits); testSession.evaluate(expected, loss); } } @@ -67,7 +68,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; float[] predArray = {2f, 1f, -1f, 0f}; Operand yTrue = @@ -75,7 +77,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -87,12 +89,13 @@ public void testUnweighted() { for 
(TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 3.83331f; testSession.evaluate(expected, loss); @@ -105,8 +108,9 @@ public void testUnweighted() { Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits); expected = 33.33333f; testSession.evaluate(expected, loss); } @@ -118,13 +122,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 8.816612f; testSession.evaluate(expected, loss); @@ -137,8 +142,9 @@ public void testScalarWeighted() { Operand yTrue1 = 
tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight); expected = 76.66667f; testSession.evaluate(expected, loss); } @@ -149,7 +155,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - BinaryCrossentropy instance = new BinaryCrossentropy(tf); + BinaryCrossentropy instance = new BinaryCrossentropy(); + float[] trueArray = {1f, 0f, 1f, 0f}; float[] predArray = {1f, 1f, 1f, 0f}; float[] sampleWeightArray = {1.2f, 3.4f}; @@ -157,7 +164,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.59997f; testSession.evaluate(expected, loss); @@ -172,8 +179,9 @@ public void testSampleWeighted() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight1 = tf.constant(sampleWeightArray1); - instance = new BinaryCrossentropy(tf, true); - loss = instance.call(yTrue1, logits, sampleWeight1); + instance = new BinaryCrossentropy(true); + + loss = instance.call(tf, yTrue1, logits, sampleWeight1); expected = 100f; testSession.evaluate(expected, loss); } @@ -196,8 +204,9 @@ public void testNoReduction() { tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); BinaryCrossentropy instance = new BinaryCrossentropy( - tf, true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); - Operand loss = 
instance.call(yTrue, logits); + true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.f, 66.666664f}; testSession.evaluate(expected, loss); } @@ -215,8 +224,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - BinaryCrossentropy instance = new BinaryCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + BinaryCrossentropy instance = new BinaryCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = (100.0f + 50.0f * labelSmoothing) / 3.0f; testSession.evaluate(expected, loss); } catch (Exception expected) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 13b287de3cd..3f6453b756a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -48,8 +48,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0F; testSession.evaluate(expected, loss); @@ -62,8 +63,9 @@ public void testAllCorrectUnweighted() { yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new 
CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); testSession.setEpsilon(1e-3F); testSession.evaluate(0.0F, loss); } @@ -81,7 +83,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] trueArray = { 1L, 0L, 0L, 0L, 1L, 0L, @@ -97,7 +100,7 @@ public void testInvalidPredictionsRange() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -109,7 +112,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; float[] predArray = { .9F, .05F, .05F, @@ -118,7 +122,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814F; testSession.evaluate(expected, loss); @@ -130,8 +134,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.0573755F; testSession.evaluate(expected, loss); } @@ -158,8 
+163,9 @@ public void testScalarWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3F); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267F; testSession.evaluate(expected, loss); @@ -171,8 +177,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386F; testSession.evaluate(expected, loss); } @@ -183,7 +190,8 @@ public void testSsampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + CategoricalCrossentropy instance = new CategoricalCrossentropy(); + float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; int[] trueArray = { 1, 0, 0, @@ -199,7 +207,7 @@ public void testSsampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampeWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696F; testSession.evaluate(expected, loss); @@ -211,8 +219,9 @@ public void testSsampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new CategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + 
instance = new CategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829F; testSession.evaluate(expected, loss); } @@ -234,9 +243,9 @@ public void testNoReduction() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - CategoricalCrossentropy instance = - new CategoricalCrossentropy(tf, true, 0.0F, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, 0.0F, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822F, 0.000459F, 0.169846F}; testSession.evaluate(expected, loss); } @@ -254,8 +263,9 @@ public void testLabelSmoothing() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf, true, labelSmoothing); - Operand loss = instance.call(yTrue, logits); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true, labelSmoothing); + + Operand loss = instance.call(tf, yTrue, logits); float expected = 400.0F * labelSmoothing / 3.0F; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java index b0d0442b3c7..d00f5374d61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -31,12 +31,13 @@ public void testReductionNone() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); + 
CategoricalHinge instance = new CategoricalHinge(Reduction.NONE); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {0.0f, 65.0f}; testSession.evaluate(expected, loss); } @@ -48,12 +49,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5}; float[] predArray = {4f, 8f, 12f, 8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 32.5f; testSession.evaluate(expected, loss); } @@ -65,17 +67,18 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83.95f; testSession.evaluate(expected, loss); - Operand loss2 = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss2 = instance.call(tf, yTrue, yPred, sampleWeight); testSession.evaluate(loss, loss2); } } @@ -85,7 +88,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {1.2f, 3.4f}; @@ -93,7 +97,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 124.1f; testSession.evaluate(expected, loss); } @@ -104,13 +108,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -121,7 +126,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CategoricalHinge instance = new 
CategoricalHinge(tf); + CategoricalHinge instance = new CategoricalHinge(); + int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] weightsNp = {3, 6, 5, 0, 4, 2}; @@ -130,7 +136,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 4.0f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java index 8350d1403ed..2f21929a969 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -33,11 +33,12 @@ public void testReductionNone() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf, Reduction.NONE); + CosineSimilarity instance = new CosineSimilarity(Reduction.NONE); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); Float[] expected = {-0.720488f, 0.3460499f}; testSession.evaluate(expected, loss); } @@ -52,11 +53,12 @@ public void testUnweighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + 
Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } @@ -71,12 +73,13 @@ public void testScalarWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, 2.3f)); testSession.evaluate(expected, loss); } @@ -90,14 +93,15 @@ public void testSampleWeighted() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] weightsArray = {1.2f, 3.4f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -mean(mul(expectedLoss, weightsArray)); testSession.evaluate(expected, loss); } @@ -108,14 +112,15 @@ public void testZeroWeighted() { for (TestSession.Mode 
tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -128,14 +133,15 @@ public void testTimestepWeighted() { Ops tf = testSession.getTF(); float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - CosineSimilarity instance = new CosineSimilarity(tf); + CosineSimilarity instance = new CosineSimilarity(); + Shape shape = Shape.of(2, 3, 1); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); float[] weightsArray = {3, 6, 5, 0, 4, 2}; Operand sampleWeight = tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -2.0f; testSession.evaluate(expected, loss); } @@ -149,11 +155,12 @@ public void testAxis() { float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] expectedLoss = {0.720488f, -0.3460499f}; - CosineSimilarity instance = new CosineSimilarity(tf, 1); + CosineSimilarity instance = new CosineSimilarity(1); + Shape shape = Shape.of(2, 3); Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); Operand yPred = tf.reshape(tf.constant(predArray), 
tf.constant(shape)); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -mean(expectedLoss); testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 4770511207e..d5fe846c82e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -33,12 +33,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.50625f; testSession.evaluate(expected, loss); } @@ -56,14 +57,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {2f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -75,13 +77,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.164375f; testSession.evaluate(expected, loss); @@ -94,7 +97,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -102,7 +106,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.06125f; testSession.evaluate(expected, loss); } @@ -113,13 +117,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); + Hinge instance = new Hinge(); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 
4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -130,7 +135,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf, Reduction.AUTO); + Hinge instance = new Hinge(Reduction.AUTO); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -140,7 +146,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0125f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java index d1751f223a1..86a71e5ecbb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -32,8 +32,9 @@ public void testAllCorrect() { float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yTrue); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } 
@@ -50,8 +51,9 @@ public void testUnweighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); - Operand loss = instance.call(yTrue, yPred); + Huber instance = new Huber(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.10416666666666669f; testSession.evaluate(expected, loss); } @@ -67,9 +69,10 @@ public void testScalarWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.23958333333333337f; testSession.evaluate(expected, loss); @@ -87,10 +90,11 @@ public void testSampleWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.22766666666666668f; testSession.evaluate(expected, loss); } @@ -105,9 +109,10 @@ public void testZeroWeighted() { float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Huber instance 
= new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -125,10 +130,11 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Huber instance = new Huber(tf); + Huber instance = new Huber(); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .4025f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java index d57b61b18dd..1d7ee87b920 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.5960738398643668f; testSession.evaluate(expected, loss); } @@ -47,13 
+48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.3709698316880434f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.0075711736936492f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); + KLDivergence instance = new KLDivergence(); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, 
.4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); + KLDivergence instance = new KLDivergence(Reduction.AUTO); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.2495994912084345f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java index c4347b3fccb..ce6782cee3b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = 
new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 4.829245330860459f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.107264260979056f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
12.001114667519486f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); + LogCosh instance = new LogCosh(); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf, Reduction.AUTO); + LogCosh instance = new LogCosh(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 11.653484271934046f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java index 3498c6d53aa..cbcb2c37391 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 5.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), 
tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 12.65f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 81.4f / 6f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ 
-116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.AUTO); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf); + MeanAbsoluteError instance = new MeanAbsoluteError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.NONE); + 
MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {10.733333f, 14.566667f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.SUM); + MeanAbsoluteError instance = new MeanAbsoluteError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {25.29999f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java index 7816a8a288a..b521f2f5644 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -30,10 +30,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -45,12 +46,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 211.85184f; testSession.evaluate(expected, loss); } @@ -62,13 +64,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 487.25922f; testSession.evaluate(expected, loss); } @@ -79,7 +82,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -87,7 +91,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 422.8889f; testSession.evaluate(expected, loss); } @@ -98,13 +102,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -115,7 +120,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : 
tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.AUTO); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -125,7 +131,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 694.4445f; testSession.evaluate(expected, loss); } @@ -136,13 +142,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.NONE); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {621.8518f, 352.66666f}; testSession.evaluate(expected, loss); } @@ -153,13 +160,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanAbsolutePercentageError instance = new 
MeanAbsolutePercentageError(tf, Reduction.SUM); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 974.51843f; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java index 1a971f0492b..e9fd0d7e349 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 
9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 49.5f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 113.85f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 
127.96667f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.AUTO); + MeanSquaredError instance = new MeanSquaredError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 97.833336f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf); + MeanSquaredError instance = new 
MeanSquaredError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 173.25f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.NONE); + MeanSquaredError instance = new MeanSquaredError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {84.333336f, 143.36665f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredError instance = new MeanSquaredError(tf, Reduction.SUM); + MeanSquaredError instance = new MeanSquaredError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 
3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {227.69998f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java index 558f9c84659..0c6d411c53f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -31,10 +31,11 @@ public void testAllCorrectUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); + Operand loss = instance.call(tf, yTrue, yTrue); float expected = 0.0f; testSession.evaluate(expected, loss); } @@ -46,12 +47,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = 
instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 1.4370421f; testSession.evaluate(expected, loss); } @@ -63,13 +65,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.3051968f; testSession.evaluate(expected, loss); } @@ -80,7 +83,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -88,7 +92,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 3.7856376f; testSession.evaluate(expected, loss); } @@ -99,13 +103,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -116,7 +121,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -126,7 +132,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 2.647374f; testSession.evaluate(expected, loss); @@ -141,7 +147,8 @@ public void testInvalidSampleWeight() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(); + 
float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f}; @@ -151,7 +158,7 @@ public void testInvalidSampleWeight() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 83f / 6f; testSession.evaluate(expected, loss); } @@ -163,13 +170,14 @@ public void testNoReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {2.3006392f, 4.3097544f}; testSession.evaluate(expected, loss); } @@ -180,13 +188,14 @@ public void testSumReduction() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand 
yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); Float[] expected = {6.6103935f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java index 55c59ca5ac6..c354c83bfe2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -30,12 +30,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = -3.306581945521002f; testSession.evaluate(expected, loss); } @@ -47,13 +48,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = 
instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -7.605138474698304f; testSession.evaluate(expected, loss); } @@ -64,7 +66,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {1.2f, 3.4f}; @@ -72,7 +75,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -6.147338926788071f; testSession.evaluate(expected, loss); } @@ -83,13 +86,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); + Poisson instance = new Poisson(); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -100,7 +104,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf, 
Reduction.AUTO); + Poisson instance = new Poisson(Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; @@ -110,7 +115,7 @@ public void testTimestepWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = -12.263126013890561f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index a6a0ff35c78..113b89b82ff 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -44,8 +44,9 @@ public void testAllCorrectUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.0f; testSession.evaluate(expected, loss); @@ -57,8 +58,9 @@ public void testAllCorrectUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, 
logits); testSession.evaluate(0.0f, loss); } } @@ -75,7 +77,8 @@ public void testInvalidPredictionsRange() { catchClass, () -> { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { 1.9f, .05f, .05f, @@ -86,7 +89,7 @@ public void testInvalidPredictionsRange() { tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -98,7 +101,8 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + int[] trueArray = {0, 1, 2}; float[] predArray = { .9f, .05f, .05f, @@ -107,7 +111,7 @@ public void testUnweighted() { }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.32396814f; testSession.evaluate(expected, loss); @@ -119,8 +123,9 @@ public void testUnweighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits); expected = 0.05737559f; testSession.evaluate(expected, loss); } @@ -143,8 +148,9 @@ public void testScalarWeighted() { Operand yPred = 
tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3f); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = .7451267f; testSession.evaluate(expected, loss); @@ -156,8 +162,9 @@ public void testScalarWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.13196386f; testSession.evaluate(expected, loss); } @@ -168,7 +175,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(); + float[] sampleWeightArray = {1.2f, 3.4f, 5.6f}; int[] trueArray = {0, 1, 2}; float[] predArray = { @@ -180,7 +188,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(3, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.0696f; testSession.evaluate(expected, loss); @@ -192,8 +200,9 @@ public void testSampleWeighted() { }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); - instance = new SparseCategoricalCrossentropy(tf, true); - loss = 
instance.call(yTrue, logits, sampleWeight); + instance = new SparseCategoricalCrossentropy(true); + + loss = instance.call(tf, yTrue, logits, sampleWeight); expected = 0.31829f; testSession.evaluate(expected, loss); } @@ -216,8 +225,9 @@ public void testNoReduction() { Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy(tf, true, Reduction.NONE); - Operand loss = instance.call(yTrue, logits); + new SparseCategoricalCrossentropy(true, Reduction.NONE); + + Operand loss = instance.call(tf, yTrue, logits); Float[] expected = {0.001822f, 0.000459f, 0.169846f}; testSession.evaluate(expected, loss); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 57a012bbe9d..979e778e4c3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -32,12 +32,13 @@ public void testUnweighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); float expected = 0.364062f; testSession.evaluate(expected, loss); } @@ -55,14 +56,15 @@ public void testInvalidLabelValue() { catchClass, () -> { Ops tf = testSession.getTF(); - SquaredHinge instance 
= new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 2, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); + Operand loss = instance.call(tf, yTrue, yPred); testSession.run(loss); }); } @@ -74,13 +76,14 @@ public void testScalarWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.8373437f; testSession.evaluate(expected, loss); } @@ -91,7 +94,8 @@ public void testSampleWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] sampleArray = {1.2f, 3.4f}; float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; @@ -99,7 +103,7 @@ public void testSampleWeighted() { Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, 
yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0.7043125f; testSession.evaluate(expected, loss); } @@ -110,13 +114,14 @@ public void testZeroWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); + SquaredHinge instance = new SquaredHinge(); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 0f; testSession.evaluate(expected, loss); } @@ -127,7 +132,8 @@ public void testTimestepWeighted() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); + SquaredHinge instance = new SquaredHinge(Reduction.AUTO); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand yTrue = @@ -137,7 +143,7 @@ public void testTimestepWeighted() { float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + Operand loss = instance.call(tf, yTrue, yPred, sampleWeight); float expected = 1.54250000f; testSession.evaluate(expected, loss); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d6786b71972..d957cfb2508 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,13 +1,17 @@ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -26,10 +30,8 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -136,14 +138,14 @@ public void testDeterminism() { Ops tf = Ops.create(g); Glorot initializer = - new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); + new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); // Inputs Placeholder input = tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20))); // Fully connected layer Variable fcWeights = - tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(20L, 200L), TFloat32.class)); fcWeightName = fcWeights.op().name(); Variable 
fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f))); fcBiasName = fcBiases.op().name(); @@ -151,13 +153,13 @@ public void testDeterminism() { // Output layer Variable outputWeights = - tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class)); + tf.variable(initializer.call(tf, tf.array(200L, 2L), TFloat32.class)); outputWeightName = outputWeights.op().name(); Variable outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f))); outputBiasName = outputBiases.op().name(); Add output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases); - // Loss + // AbstractLoss Placeholder placeholder = tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); Mean loss = @@ -205,12 +207,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - TFloat32 lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + TFloat32 lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); initialLoss[i] = lossVal.getFloat(); lossVal.close(); @@ -222,12 +227,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); postTrainingLoss[i] = lossVal.getFloat(); lossVal.close(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 181ae367f07..a4b98c002cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java 
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -17,25 +17,25 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.2f, 0.3f); + L1L2 instance = new L1L2(0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2(tf, 0, 0); + instance = new L1L2(0, 0); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0.5f, 0); + instance = new L1L2(0.5f, 0); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0, 0.5f); + instance = new L1L2(0, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); - instance = new L1L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + instance = new L1L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); } } @@ -44,8 +44,8 @@ public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf); - Operand result = instance.call(tf.constant(555f)); + L1L2 instance = new L1L2(); + Operand result = instance.call(tf, tf.constant(555f)); session.evaluate(3085.8f, result); } } @@ -55,10 +55,10 @@ public void testCallL1L2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0); + L1L2 instance = new L1L2(0, 0); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); - Operand 
result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -68,10 +68,10 @@ public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -83,10 +83,10 @@ public void testCallL1L2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + L1L2 instance = new L1L2(0.01f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1L2(w, 0.01f, 0.02f); session.setEpsilon(.09f); session.evaluate(expected, result); @@ -98,10 +98,10 @@ public void testCallL2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, 0); + L1L2 instance = new L1L2(0.01f, 0); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); float expected = regularizeL1(w, 0.01f); session.evaluate(expected, result); } @@ -112,10 +112,10 @@ public void testCallL1_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0, 0.02f); 
+ L1L2 instance = new L1L2(0, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 0e42a257816..f7d540fb8e1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.2f); + L1 instance = new L1(0.2f); assertEquals(0.2f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf, 0f); + instance = new L1(0f); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + instance = new L1(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(0.f, instance.getL2()); } } @@ -36,10 +36,10 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.0f); + L1 instance = new L1(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0f, result); } } @@ -49,11 +49,11 @@ public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession 
session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf); + L1 instance = new L1(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL1(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.evaluate(expected, result); } } @@ -63,10 +63,10 @@ public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1(tf, 0.02f); + L1 instance = new L1(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL1(w, 0.02f); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index aba036ee306..4579ccaf551 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -17,16 +17,16 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.2f); + L2 instance = new L2(0.2f); assertEquals(0.2f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2(tf, 0f); + instance = new L2(0f); assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - L2 instance64 = new L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + L2 
instance64 = new L2(); + assertEquals(AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); assertEquals(0.f, instance64.getL1()); } } @@ -36,10 +36,10 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.0f); + L2 instance = new L2(0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); session.evaluate(0, result); } } @@ -49,11 +49,11 @@ public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf); + L2 instance = new L2(); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + Operand result = instance.call(tf, weights); + float expected = regularizeL2(w, AbstractRegularizer.DEFAULT_REGULARIZATION_PENALTY); session.setEpsilon(.01f); session.evaluate(expected, result); } @@ -64,10 +64,10 @@ public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2(tf, 0.02f); + L2 instance = new L2(0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); - Operand result = instance.call(weights); + Operand result = instance.call(tf, weights); double expected = regularizeL2(w, 0.02f); session.setEpsilon(.01f); session.evaluate(expected, result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index fe2624cec3d..6918f631e6a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -14,13 +14,13 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 regularizer = new L1L2(tf, 0.01f, 0f); + L1L2 regularizer = new L1L2(0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); - Operand regularizerResult = regularizer.call(weights); - RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + Operand regularizerResult = regularizer.call(tf, weights); + RegularizerLoss lossInstance = new RegularizerLoss(regularizer); - Operand loss = lossInstance.call(null, null, weights); + Operand loss = lossInstance.call(tf, null, null, weights); session.evaluate(regularizerResult, loss); } } From fe7e8e30e15158d18645aebd9275527c7c822e4a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 2 Jun 2021 13:51:07 -0400 Subject: [PATCH 4/5] JavaDoc fixes including Dataset --- .../tensorflow/framework/data/Dataset.java | 85 +++++++++++++++---- .../framework/data/DatasetOptional.java | 37 +++++++- .../losses/SparseCategoricalCrossentropy.java | 9 +- .../framework/regularizers/L1L2.java | 1 + 4 files changed, 107 insertions(+), 25 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 7ac73f616e2..e227ed5d305 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -23,16 +23,16 @@ import 
org.tensorflow.framework.data.impl.TakeDataset; import org.tensorflow.framework.data.impl.TensorSliceDataset; import org.tensorflow.framework.data.impl.TextLineDataset; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.function.Function; -import org.tensorflow.types.family.TType; /** * Represents a potentially large list of independent elements (samples), and allows iteration and @@ -44,8 +44,19 @@ public abstract class Dataset implements Iterable>> { private List> outputTypes; private List outputShapes; + /** + * Creates a Dataset + * + * @param tf the TensorFlow Ops + * @param variant the tensor that represents the dataset. + * @param outputTypes a list of output types produced by this data set. + * @param outputShapes a list of output shapes produced by this data set. + */ public Dataset( - Ops tf, Operand variant, List> outputTypes, List outputShapes) { + Ops tf, + Operand variant, + List> outputTypes, + List outputShapes) { if (tf == null) { throw new IllegalArgumentException("Ops accessor cannot be null."); } @@ -61,6 +72,11 @@ public Dataset( this.outputShapes = outputShapes; } + /** + * Creates a Dataset that is a copy of another Dataset + * + * @param other the other Dataset + */ protected Dataset(Dataset other) { this.tf = other.tf; this.variant = other.variant; @@ -127,11 +143,12 @@ public final Dataset take(long count) { * Returns a new Dataset which maps a function across all elements from this dataset, on a single * component of each element. * - *

        For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * - *

        Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2)))} will - * map the function over the `features` component of each element, multiplying each by 2. + *

        Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, + * tf.constant(2)))} will map the function over the `features` component of each element, + * multiplying each by 2. * * @param index The index of the component to transform. * @param mapper The function to apply to the target component. @@ -150,8 +167,8 @@ public Dataset mapOneComponent(int index, Function, Operand> mappe * Returns a new Dataset which maps a function across all elements from this dataset, on all * components of each element. * - *

        For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * *

        Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component, * tf.constant(2)))} will map the function over the both the `features` and `labels` components of @@ -172,8 +189,8 @@ public Dataset mapAllComponents(Function, Operand> mapper) { /** * Returns a new Dataset which maps a function over all elements returned by this dataset. * - *

        For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * *

        Calling * @@ -261,8 +278,8 @@ public DatasetIterator makeOneShotIterator() { * @param tf Ops Accessor * @param tensors A list of {@code Operand} representing components of this dataset (e.g. * features, labels) - * @param outputTypes A list of tensor type classes representing the data type of each component of - * this dataset. + * @param outputTypes A list of tensor type classes representing the data type of each component + * of this dataset. * @return A new `Dataset` */ public static Dataset fromTensorSlices( @@ -270,37 +287,73 @@ public static Dataset fromTensorSlices( return new TensorSliceDataset(tf, tensors, outputTypes); } + /** + * Creates a TFRecordDataset from a file containing TFRecords + * + * @param tf the TensorFlow Ops + * @param filename the file name that holds the TFRecords + * @param compressionType the compresstion type for the file + * @param bufferSize the buffersize for processing the TFRecords file. + * @return a TFRecordDataset + */ public static Dataset tfRecordDataset( Ops tf, String filename, String compressionType, long bufferSize) { return new TFRecordDataset( tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); } + /** + * Creates a TextLineDataset from a file containing one recored per ling. + * + * @param tf the TensorFlow Ops + * @param filename the file name that holds the data records + * @param compressionType the compresstion type for the file + * @param bufferSize the buffersize for processing the records file. + * @return a TextLineDataset + */ public static Dataset textLineDataset( Ops tf, String filename, String compressionType, long bufferSize) { return new TextLineDataset( tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); } - /** Get the variant tensor representing this dataset. */ + /** + * Gets the variant tensor representing this dataset. + * + * @return the variant tensor representing this dataset. 
+ */ public Operand getVariant() { return variant; } - /** Get a list of output types for each component of this dataset. */ + /** + * Gets a list of output types for each component of this dataset. + * + * @return list of output types for each component of this dataset. + */ public List> getOutputTypes() { return this.outputTypes; } - /** Get a list of shapes for each component of this dataset. */ + /** + * Gets a list of shapes for each component of this dataset. + * + * @return a list of shapes for each component of this dataset. + */ public List getOutputShapes() { return this.outputShapes; } + /** + * Gets the TensorFlow Ops instance for this dataset + * + * @return the TensorFlow Ops instance for this dataset + */ public Ops getOpsInstance() { return this.tf; } + /** {@inheritDoc} */ @Override public String toString() { return "Dataset{" diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java index 6617c33eaf7..f1df7b78e94 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java @@ -31,6 +31,10 @@ public class DatasetOptional { protected Ops tf; + /** + * Gets the optional variant for this Dataset + * @return the optional variant for this Dataset + */ public Operand getOptionalVariant() { return optionalVariant; } @@ -39,6 +43,13 @@ public Operand getOptionalVariant() { private List> outputTypes; private List outputShapes; + /** + * Creates a DatasetOptional dataset + * @param tf the TensorFlow Ops + * @param optionalVariant the tensor that represents the dataset. + * @param outputTypes a list of output types produced by this data set. + * @param outputShapes a list of output shapes produced by this data set. 
+ */ public DatasetOptional( Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { this.tf = tf; @@ -47,6 +58,11 @@ public DatasetOptional( this.outputShapes = outputShapes; } + /** + * Creates a Dataset that is a copy of another Dataset + * + * @param other the other Dataset + */ protected DatasetOptional(DatasetOptional other) { this.tf = other.tf; this.optionalVariant = other.optionalVariant; @@ -56,12 +72,16 @@ protected DatasetOptional(DatasetOptional other) { - /** Whether this optional has a value. */ + /** Gets the indicator of whether this optional has a value. + * @return the indicator of whether this optional has a value. + */ public Operand hasValue() { return tf.data.optionalHasValue(optionalVariant).hasValue(); } - /** Returns the value of the dataset element represented by this optional, if it exists. */ + /** Returns the value of the dataset element represented by this optional, if it exists. + * @return the value of the dataset element represented by this optional, if it exists. + */ public List> getValue() { List> components = new ArrayList<>(); tf.data @@ -72,6 +92,14 @@ public List> getValue() { return components; } + /** + * Creates a DatasetOptional from components. + * @param tf the TensorFlow Ops + * @param components the components that constitute the DatasetOptional + * @param outputTypes a list of output types produced by this data set. + * @param outputShapes a list of output shapes produced by this data set. 
+ * @return a a DatasetOptional + */ public static DatasetOptional fromComponents( Ops tf, List> components, @@ -81,6 +109,11 @@ public static DatasetOptional fromComponents( return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes); } + /** + * Gets the TensorFlow Ops instance for this dataset + * + * @return the TensorFlow Ops instance for this dataset + */ public Ops getOpsInstance() { return tf; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 291a91894b0..13b4eb8225a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -84,7 +84,7 @@ public class SparseCategoricalCrossentropy extends AbstractLoss { * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops + */ public SparseCategoricalCrossentropy() { this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); @@ -104,7 +104,6 @@ public SparseCategoricalCrossentropy(String name) { * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss * name, with Reduction.AUTO and fromLogits={@link #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param reduction Type of Reduction to apply to loss. */ public SparseCategoricalCrossentropy(Reduction reduction) { @@ -115,7 +114,6 @@ public SparseCategoricalCrossentropy(Reduction reduction) { * Creates a SparseCategoricalCrossentropy loss with Reduction.AUTO and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of this loss function * @param reduction Type of Reduction to apply to loss. 
*/ @@ -127,7 +125,6 @@ public SparseCategoricalCrossentropy(String name, Reduction reduction) { * Creates a SparseCategoricalCrossentropy using a AbstractLoss Reduction of {@link * AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param name the name of this loss function * @param fromLogits Whether to interpret predictions as a tensor of logit values */ @@ -140,7 +137,6 @@ public SparseCategoricalCrossentropy(String name, boolean fromLogits) { * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT} and fromLogits={@link * #FROM_LOGITS_DEFAULT}. * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values */ public SparseCategoricalCrossentropy(boolean fromLogits) { @@ -151,7 +147,6 @@ public SparseCategoricalCrossentropy(boolean fromLogits) { * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss * name, * - * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param reduction Type of Reduction to apply to loss. */ @@ -162,7 +157,6 @@ public SparseCategoricalCrossentropy(boolean fromLogits, Reduction reduction) { /** * Creates a SparseCategoricalCrossentropy * - * @param tf the TensorFlow Ops * @param name the name of this loss function * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param reduction Type of Reduction to apply to loss. @@ -184,6 +178,7 @@ public SparseCategoricalCrossentropy( * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if * the predictions values are outside the range o [0. to 1.] * + * @param tf the TensorFlow Ops * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. 
If a scalar is diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 6dfaf3f0d47..f5a8c956072 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -57,6 +57,7 @@ public L1L2(float l1, float l2) { /** * Creates an L1L2 regularizer * + * @param name the name for this regularizer, if null the class name will be used. * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} From c5ae13b905a4a8dda064009036bf0868fa1d521c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 2 Jun 2021 16:56:26 -0400 Subject: [PATCH 5/5] Results of Run mvn spotless:apply --- .../tensorflow/framework/activations/ELU.java | 4 +- .../framework/activations/HardSigmoid.java | 4 +- .../framework/activations/ReLU.java | 4 +- .../constraints/AbstractConstraint.java | 4 +- .../framework/constraints/MaxNorm.java | 4 +- .../framework/constraints/MinMaxNorm.java | 4 +- .../framework/constraints/UnitNorm.java | 4 +- .../tensorflow/framework/data/Dataset.java | 731 +++++++++--------- .../framework/data/DatasetOptional.java | 27 +- .../framework/initializers/Constant.java | 4 +- .../framework/initializers/Identity.java | 4 +- .../framework/initializers/Ones.java | 4 +- .../framework/initializers/Orthogonal.java | 4 +- .../framework/initializers/RandomNormal.java | 4 +- .../framework/initializers/RandomUniform.java | 4 +- .../initializers/TruncatedNormal.java | 4 +- .../initializers/VarianceScaling.java | 4 +- .../framework/losses/BinaryCrossentropy.java | 4 +- .../losses/CategoricalCrossentropy.java | 4 +- .../tensorflow/framework/losses/Hinge.java | 4 +- 
.../losses/SparseCategoricalCrossentropy.java | 6 +- .../framework/losses/SquaredHinge.java | 4 +- .../org/tensorflow/framework/metrics/AUC.java | 19 +- .../framework/metrics/Accuracy.java | 4 +- .../framework/metrics/BinaryAccuracy.java | 4 +- .../metrics/CategoricalAccuracy.java | 4 +- .../metrics/CategoricalCrossentropy.java | 4 +- .../tensorflow/framework/metrics/MeanIoU.java | 11 +- .../framework/metrics/MeanRelativeError.java | 7 +- .../framework/metrics/MeanTensor.java | 9 +- .../framework/metrics/Precision.java | 15 +- .../framework/metrics/PrecisionAtRecall.java | 4 +- .../tensorflow/framework/metrics/Recall.java | 15 +- .../framework/metrics/RecallAtPrecision.java | 6 +- .../metrics/RootMeanSquaredError.java | 7 +- .../metrics/SensitivityAtSpecificity.java | 4 +- .../metrics/SparseCategoricalAccuracy.java | 7 +- .../metrics/SpecificityAtSensitivity.java | 4 +- .../metrics/TopKCategoricalAccuracy.java | 4 +- .../impl/ConfusionMatrixConditionCount.java | 12 +- .../metrics/impl/MeanMetricWrapper.java | 7 +- .../framework/metrics/impl/MetricsHelper.java | 21 +- .../impl/SensitivitySpecificityBase.java | 15 +- .../framework/metrics/impl/SetsOps.java | 4 +- .../framework/metrics/impl/SymbolicShape.java | 5 +- .../metrics/impl/WeightsBroadcastOps.java | 11 +- .../framework/regularizers/L1L2.java | 4 +- .../framework/activations/SELUTest.java | 11 +- .../framework/activations/TanhTest.java | 24 +- .../framework/constraints/MaxNormTest.java | 5 +- .../framework/constraints/MinMaxNormTest.java | 5 +- .../framework/initializers/ConstantTest.java | 6 +- .../framework/initializers/OnesTest.java | 6 +- .../losses/BinaryCrossentropyTest.java | 28 +- .../losses/CategoricalCrossentropyTest.java | 10 +- .../framework/losses/HingeTest.java | 4 +- .../SparseCategoricalCrossentropyTest.java | 4 +- .../framework/losses/SquaredHingeTest.java | 4 +- .../optimizers/GradientDescentTest.java | 9 +- .../framework/regularizers/L1L2Test.java | 4 +- 
.../framework/regularizers/L1Test.java | 4 +- .../framework/regularizers/L2Test.java | 4 +- 62 files changed, 579 insertions(+), 596 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 919a947a127..bd019a60df1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -14,13 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.activations; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Exponential linear unit. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index fac4d14eca5..4365e0cd14a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.activations; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Hard sigmoid activation. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index c966e5d9ddd..44dd3bc3b46 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.activations; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.math.Greater; import org.tensorflow.op.nn.LeakyRelu; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Rectified Linear Unit(ReLU) activation. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java index 266d01620bd..15db0d4b1e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/AbstractConstraint.java @@ -14,13 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.constraints; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Base class for Constraints. 
AbstractConstraint subclasses impose constraints on weight values */ public abstract class AbstractConstraint implements Constraint { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index b9f082f54de..9bb99c47d07 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.constraints; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Constrains the weights incident to each hidden unit to have a norm less than or equal to a * desired value. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 97e86d7693f..49b06744253 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.constraints; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Constrains the weights to have the norm between a lower bound and an upper bound. 
*/ public class MinMaxNorm extends AbstractConstraint { public static final double MIN_VALUE_DEFAULT = 0.0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java index fdd71945229..8410605fab0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.constraints; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Constrains the weights to have unit norm. */ public class UnitNorm extends AbstractConstraint { public static final int AXIS_DEFAULT = 0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index e227ed5d305..8ae751823fe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -1,366 +1,365 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.framework.data; - -import org.tensorflow.Operand; -import org.tensorflow.framework.data.impl.BatchDataset; -import org.tensorflow.framework.data.impl.MapDataset; -import org.tensorflow.framework.data.impl.SkipDataset; -import org.tensorflow.framework.data.impl.TFRecordDataset; -import org.tensorflow.framework.data.impl.TakeDataset; -import org.tensorflow.framework.data.impl.TensorSliceDataset; -import org.tensorflow.framework.data.impl.TextLineDataset; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TType; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.function.Function; - -/** - * Represents a potentially large list of independent elements (samples), and allows iteration and - * transformations to be performed across these elements. - */ -public abstract class Dataset implements Iterable>> { - protected Ops tf; - private Operand variant; - private List> outputTypes; - private List outputShapes; - - /** - * Creates a Dataset - * - * @param tf the TensorFlow Ops - * @param variant the tensor that represents the dataset. - * @param outputTypes a list of output types produced by this data set. - * @param outputShapes a list of output shapes produced by this data set. 
- */ - public Dataset( - Ops tf, - Operand variant, - List> outputTypes, - List outputShapes) { - if (tf == null) { - throw new IllegalArgumentException("Ops accessor cannot be null."); - } - - if (outputTypes.size() != outputShapes.size()) { - throw new IllegalArgumentException( - "`outputTypes` and " + "`outputShapes` must have the same size."); - } - - this.tf = tf; - this.variant = variant; - this.outputTypes = outputTypes; - this.outputShapes = outputShapes; - } - - /** - * Creates a Dataset that is a copy of another Dataset - * - * @param other the other Dataset - */ - protected Dataset(Dataset other) { - this.tf = other.tf; - this.variant = other.variant; - this.outputTypes = other.outputTypes; - this.outputShapes = other.outputShapes; - } - - /** - * Groups elements of this dataset into batches. - * - * @param batchSize The number of desired elements per batch - * @param dropLastBatch Whether to leave out the final batch if it has fewer than `batchSize` - * elements. - * @return A batched Dataset - */ - public final Dataset batch(long batchSize, boolean dropLastBatch) { - - List batchOutputShapes = new ArrayList<>(); - outputShapes.forEach(s -> batchOutputShapes.add(s.prepend(-1))); - - return new BatchDataset( - tf, - this.getVariant(), - tf.constant(batchSize), - tf.constant(dropLastBatch), - outputTypes, - batchOutputShapes); - } - - /** - * Groups elements of this dataset into batches. Includes the last batch, even if it has fewer - * than `batchSize` elements. - * - * @param batchSize The number of desired elements per batch - * @return A batched Dataset - */ - public final Dataset batch(long batchSize) { - return batch(batchSize, false); - } - - /** - * Returns a new `Dataset` which skips `count` initial elements from this dataset - * - * @param count The number of elements to `skip` to form the new dataset. - * @return A new Dataset with `count` elements removed. 
- */ - public final Dataset skip(long count) { - return new SkipDataset( - tf, this.getVariant(), tf.constant(count), this.getOutputTypes(), this.getOutputShapes()); - } - - /** - * Returns a new `Dataset` with only the first `count` elements from this dataset. - * - * @param count The number of elements to "take" from this dataset. - * @return A new Dataset containing the first `count` elements from this dataset. - */ - public final Dataset take(long count) { - return new TakeDataset( - tf, this.getVariant(), tf.constant(count), this.getOutputTypes(), this.getOutputShapes()); - } - - /** - * Returns a new Dataset which maps a function across all elements from this dataset, on a single - * component of each element. - * - *

        For example, suppose each element is a {@code List>} with 2 components: - * (features, labels). - * - *

        Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, - * tf.constant(2)))} will map the function over the `features` component of each element, - * multiplying each by 2. - * - * @param index The index of the component to transform. - * @param mapper The function to apply to the target component. - * @return A new Dataset applying `mapper` to the component at the chosen index. - */ - public Dataset mapOneComponent(int index, Function, Operand> mapper) { - return map( - outputs -> { - List> newComponents = new ArrayList<>(outputs); - newComponents.set(index, mapper.apply(outputs.get(index))); - return newComponents; - }); - } - - /** - * Returns a new Dataset which maps a function across all elements from this dataset, on all - * components of each element. - * - *

        For example, suppose each element is a {@code List>} with 2 components: - * (features, labels). - * - *

        Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component, - * tf.constant(2)))} will map the function over the both the `features` and `labels` components of - * each element, multiplying them all by 2 - * - * @param mapper The function to apply to each component - * @return A new Dataset applying `mapper` to all components of each element. - */ - public Dataset mapAllComponents(Function, Operand> mapper) { - return map( - outputs -> { - List> mappedOutputs = new ArrayList<>(); - outputs.forEach(o -> mappedOutputs.add(mapper.apply(o))); - return mappedOutputs; - }); - } - - /** - * Returns a new Dataset which maps a function over all elements returned by this dataset. - * - *

        For example, suppose each element is a {@code List>} with 2 components: - * (features, labels). - * - *

        Calling - * - *

        {@code
        -   * dataset.map(components -> {
        -   *      Operand features = components.get(0);
        -   *      Operand labels   = components.get(1);
        -   *
        -   *      return Arrays.asList(
        -   *        tf.math.mul(features, tf.constant(2)),
        -   *        tf.math.mul(labels, tf.constant(5))
        -   *      );
        -   * });
        -   * }
        - * - * will map the function over the `features` and `labels` components, multiplying features by 2, - * and multiplying the labels by 5. - * - * @param mapper The function to apply to each element of this iterator. - * @return A new Dataset applying `mapper` to each element of this iterator. - */ - public Dataset map(Function>, List>> mapper) { - return new MapDataset(this, mapper); - } - - /** - * Creates an iterator which iterates through all batches of this Dataset in an eager fashion. - * Each batch is a list of components, returned as `Output` objects. - * - *

        This method enables for-each iteration through batches when running in eager mode. For Graph - * mode batch iteration, see `makeOneShotIterator`. - * - * @return an Iterator through batches of this dataset. - */ - @Override - public Iterator>> iterator() { - return makeOneShotIterator().iterator(); - } - - /** - * Creates a `DatasetIterator` that can be used to iterate over elements of this dataset. - * - *

        This iterator will have to be initialized with a call to `iterator.makeInitializer(Dataset)` - * before elements can be retreived in a loop. - * - * @return A new `DatasetIterator` based on this dataset's structure. - */ - public DatasetIterator makeInitializeableIterator() { - DatasetIterator iterator = DatasetIterator.fromStructure(tf, outputTypes, outputShapes); - iterator.makeInitializer(this); - return iterator; - } - - /** - * Creates a `DatasetIterator` that can be used to iterate over elements of this dataset. Using - * `makeOneShotIterator` ensures that the iterator is automatically initialized on this dataset. - * skips In graph mode, the initializer op will be added to the Graph's intitializer list, which - * must be run via `tf.init()`: - * - *

        Ex: - * - *

        -   *     try (Session session = new Session(graph) {
        -   *         // Immediately run initializers
        -   *         session.run(tf.init());
        -   *     }
        -   * 
        - * - *

        In eager mode, the initializer will be run automatically as a result of this call. - * - * @return A new `DatasetIterator` based on this dataset's structure. - */ - public DatasetIterator makeOneShotIterator() { - DatasetIterator iterator = makeInitializeableIterator(); - Op initializer = iterator.makeInitializer(this); - if (tf.scope().env().isGraph()) tf.initAdd(initializer); - return iterator; - } - - /** - * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of - * this dataset will be a {@code List>}, representing slices (e.g. batches) of the - * provided tensors. - * - * @param tf Ops Accessor - * @param tensors A list of {@code Operand} representing components of this dataset (e.g. - * features, labels) - * @param outputTypes A list of tensor type classes representing the data type of each component - * of this dataset. - * @return A new `Dataset` - */ - public static Dataset fromTensorSlices( - Ops tf, List> tensors, List> outputTypes) { - return new TensorSliceDataset(tf, tensors, outputTypes); - } - - /** - * Creates a TFRecordDataset from a file containing TFRecords - * - * @param tf the TensorFlow Ops - * @param filename the file name that holds the TFRecords - * @param compressionType the compresstion type for the file - * @param bufferSize the buffersize for processing the TFRecords file. - * @return a TFRecordDataset - */ - public static Dataset tfRecordDataset( - Ops tf, String filename, String compressionType, long bufferSize) { - return new TFRecordDataset( - tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); - } - - /** - * Creates a TextLineDataset from a file containing one recored per ling. - * - * @param tf the TensorFlow Ops - * @param filename the file name that holds the data records - * @param compressionType the compresstion type for the file - * @param bufferSize the buffersize for processing the records file. 
- * @return a TextLineDataset - */ - public static Dataset textLineDataset( - Ops tf, String filename, String compressionType, long bufferSize) { - return new TextLineDataset( - tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); - } - - /** - * Gets the variant tensor representing this dataset. - * - * @return the variant tensor representing this dataset. - */ - public Operand getVariant() { - return variant; - } - - /** - * Gets a list of output types for each component of this dataset. - * - * @return list of output types for each component of this dataset. - */ - public List> getOutputTypes() { - return this.outputTypes; - } - - /** - * Gets a list of shapes for each component of this dataset. - * - * @return a list of shapes for each component of this dataset. - */ - public List getOutputShapes() { - return this.outputShapes; - } - - /** - * Gets the TensorFlow Ops instance for this dataset - * - * @return the TensorFlow Ops instance for this dataset - */ - public Ops getOpsInstance() { - return this.tf; - } - - /** {@inheritDoc} */ - @Override - public String toString() { - return "Dataset{" - + "outputTypes=" - + Arrays.toString(getOutputTypes().stream().map(Class::getSimpleName).toArray()) - + ", outputShapes=" - + Arrays.toString(getOutputShapes().stream().map(Shape::toString).toArray()) - + "}"; - } -} +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.tensorflow.framework.data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; +import org.tensorflow.Operand; +import org.tensorflow.framework.data.impl.BatchDataset; +import org.tensorflow.framework.data.impl.MapDataset; +import org.tensorflow.framework.data.impl.SkipDataset; +import org.tensorflow.framework.data.impl.TFRecordDataset; +import org.tensorflow.framework.data.impl.TakeDataset; +import org.tensorflow.framework.data.impl.TensorSliceDataset; +import org.tensorflow.framework.data.impl.TextLineDataset; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TType; + +/** + * Represents a potentially large list of independent elements (samples), and allows iteration and + * transformations to be performed across these elements. + */ +public abstract class Dataset implements Iterable>> { + protected Ops tf; + private Operand variant; + private List> outputTypes; + private List outputShapes; + + /** + * Creates a Dataset + * + * @param tf the TensorFlow Ops + * @param variant the tensor that represents the dataset. + * @param outputTypes a list of output types produced by this data set. + * @param outputShapes a list of output shapes produced by this data set. 
+ */ + public Dataset( + Ops tf, + Operand variant, + List> outputTypes, + List outputShapes) { + if (tf == null) { + throw new IllegalArgumentException("Ops accessor cannot be null."); + } + + if (outputTypes.size() != outputShapes.size()) { + throw new IllegalArgumentException( + "`outputTypes` and " + "`outputShapes` must have the same size."); + } + + this.tf = tf; + this.variant = variant; + this.outputTypes = outputTypes; + this.outputShapes = outputShapes; + } + + /** + * Creates a Dataset that is a copy of another Dataset + * + * @param other the other Dataset + */ + protected Dataset(Dataset other) { + this.tf = other.tf; + this.variant = other.variant; + this.outputTypes = other.outputTypes; + this.outputShapes = other.outputShapes; + } + + /** + * Groups elements of this dataset into batches. + * + * @param batchSize The number of desired elements per batch + * @param dropLastBatch Whether to leave out the final batch if it has fewer than `batchSize` + * elements. + * @return A batched Dataset + */ + public final Dataset batch(long batchSize, boolean dropLastBatch) { + + List batchOutputShapes = new ArrayList<>(); + outputShapes.forEach(s -> batchOutputShapes.add(s.prepend(-1))); + + return new BatchDataset( + tf, + this.getVariant(), + tf.constant(batchSize), + tf.constant(dropLastBatch), + outputTypes, + batchOutputShapes); + } + + /** + * Groups elements of this dataset into batches. Includes the last batch, even if it has fewer + * than `batchSize` elements. + * + * @param batchSize The number of desired elements per batch + * @return A batched Dataset + */ + public final Dataset batch(long batchSize) { + return batch(batchSize, false); + } + + /** + * Returns a new `Dataset` which skips `count` initial elements from this dataset + * + * @param count The number of elements to `skip` to form the new dataset. + * @return A new Dataset with `count` elements removed. 
+ */ + public final Dataset skip(long count) { + return new SkipDataset( + tf, this.getVariant(), tf.constant(count), this.getOutputTypes(), this.getOutputShapes()); + } + + /** + * Returns a new `Dataset` with only the first `count` elements from this dataset. + * + * @param count The number of elements to "take" from this dataset. + * @return A new Dataset containing the first `count` elements from this dataset. + */ + public final Dataset take(long count) { + return new TakeDataset( + tf, this.getVariant(), tf.constant(count), this.getOutputTypes(), this.getOutputShapes()); + } + + /** + * Returns a new Dataset which maps a function across all elements from this dataset, on a single + * component of each element. + * + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). + * + *

        Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, + * tf.constant(2)))} will map the function over the `features` component of each element, + * multiplying each by 2. + * + * @param index The index of the component to transform. + * @param mapper The function to apply to the target component. + * @return A new Dataset applying `mapper` to the component at the chosen index. + */ + public Dataset mapOneComponent(int index, Function, Operand> mapper) { + return map( + outputs -> { + List> newComponents = new ArrayList<>(outputs); + newComponents.set(index, mapper.apply(outputs.get(index))); + return newComponents; + }); + } + + /** + * Returns a new Dataset which maps a function across all elements from this dataset, on all + * components of each element. + * + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). + * + *

        Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component, + * tf.constant(2)))} will map the function over the both the `features` and `labels` components of + * each element, multiplying them all by 2 + * + * @param mapper The function to apply to each component + * @return A new Dataset applying `mapper` to all components of each element. + */ + public Dataset mapAllComponents(Function, Operand> mapper) { + return map( + outputs -> { + List> mappedOutputs = new ArrayList<>(); + outputs.forEach(o -> mappedOutputs.add(mapper.apply(o))); + return mappedOutputs; + }); + } + + /** + * Returns a new Dataset which maps a function over all elements returned by this dataset. + * + *

        For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). + * + *

        Calling + * + *

        {@code
        +   * dataset.map(components -> {
        +   *      Operand features = components.get(0);
        +   *      Operand labels   = components.get(1);
        +   *
        +   *      return Arrays.asList(
        +   *        tf.math.mul(features, tf.constant(2)),
        +   *        tf.math.mul(labels, tf.constant(5))
        +   *      );
        +   * });
        +   * }
        + * + * will map the function over the `features` and `labels` components, multiplying features by 2, + * and multiplying the labels by 5. + * + * @param mapper The function to apply to each element of this iterator. + * @return A new Dataset applying `mapper` to each element of this iterator. + */ + public Dataset map(Function>, List>> mapper) { + return new MapDataset(this, mapper); + } + + /** + * Creates an iterator which iterates through all batches of this Dataset in an eager fashion. + * Each batch is a list of components, returned as `Output` objects. + * + *

        This method enables for-each iteration through batches when running in eager mode. For Graph + * mode batch iteration, see `makeOneShotIterator`. + * + * @return an Iterator through batches of this dataset. + */ + @Override + public Iterator>> iterator() { + return makeOneShotIterator().iterator(); + } + + /** + * Creates a `DatasetIterator` that can be used to iterate over elements of this dataset. + * + *

        This iterator will have to be initialized with a call to `iterator.makeInitializer(Dataset)` + * before elements can be retrieved in a loop. + * + * @return A new `DatasetIterator` based on this dataset's structure. + */ + public DatasetIterator makeInitializeableIterator() { + DatasetIterator iterator = DatasetIterator.fromStructure(tf, outputTypes, outputShapes); + iterator.makeInitializer(this); + return iterator; + } + + /** + * Creates a `DatasetIterator` that can be used to iterate over elements of this dataset. Using + * `makeOneShotIterator` ensures that the iterator is automatically initialized on this dataset. + * In graph mode, the initializer op will be added to the Graph's initializer list, which + * must be run via `tf.init()`: + * + *

        Ex: + * + *

        +   *     try (Session session = new Session(graph)) {
        +   *         // Immediately run initializers
        +   *         session.run(tf.init());
        +   *     }
        +   * 
        + * + *

        In eager mode, the initializer will be run automatically as a result of this call. + * + * @return A new `DatasetIterator` based on this dataset's structure. + */ + public DatasetIterator makeOneShotIterator() { + DatasetIterator iterator = makeInitializeableIterator(); + Op initializer = iterator.makeInitializer(this); + if (tf.scope().env().isGraph()) tf.initAdd(initializer); + return iterator; + } + + /** + * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of + * this dataset will be a {@code List>}, representing slices (e.g. batches) of the + * provided tensors. + * + * @param tf Ops Accessor + * @param tensors A list of {@code Operand} representing components of this dataset (e.g. + * features, labels) + * @param outputTypes A list of tensor type classes representing the data type of each component + * of this dataset. + * @return A new `Dataset` + */ + public static Dataset fromTensorSlices( + Ops tf, List> tensors, List> outputTypes) { + return new TensorSliceDataset(tf, tensors, outputTypes); + } + + /** + * Creates a TFRecordDataset from a file containing TFRecords + * + * @param tf the TensorFlow Ops + * @param filename the file name that holds the TFRecords + * @param compressionType the compresstion type for the file + * @param bufferSize the buffersize for processing the TFRecords file. + * @return a TFRecordDataset + */ + public static Dataset tfRecordDataset( + Ops tf, String filename, String compressionType, long bufferSize) { + return new TFRecordDataset( + tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); + } + + /** + * Creates a TextLineDataset from a file containing one recored per ling. + * + * @param tf the TensorFlow Ops + * @param filename the file name that holds the data records + * @param compressionType the compresstion type for the file + * @param bufferSize the buffersize for processing the records file. 
+ * @return a TextLineDataset + */ + public static Dataset textLineDataset( + Ops tf, String filename, String compressionType, long bufferSize) { + return new TextLineDataset( + tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); + } + + /** + * Gets the variant tensor representing this dataset. + * + * @return the variant tensor representing this dataset. + */ + public Operand getVariant() { + return variant; + } + + /** + * Gets a list of output types for each component of this dataset. + * + * @return list of output types for each component of this dataset. + */ + public List> getOutputTypes() { + return this.outputTypes; + } + + /** + * Gets a list of shapes for each component of this dataset. + * + * @return a list of shapes for each component of this dataset. + */ + public List getOutputShapes() { + return this.outputShapes; + } + + /** + * Gets the TensorFlow Ops instance for this dataset + * + * @return the TensorFlow Ops instance for this dataset + */ + public Ops getOpsInstance() { + return this.tf; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Dataset{" + + "outputTypes=" + + Arrays.toString(getOutputTypes().stream().map(Class::getSimpleName).toArray()) + + ", outputShapes=" + + Arrays.toString(getOutputShapes().stream().map(Shape::toString).toArray()) + + "}"; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java index f1df7b78e94..45da020b105 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java @@ -15,13 +15,12 @@ */ package org.tensorflow.framework.data; +import java.util.ArrayList; +import java.util.List; import org.tensorflow.Operand; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import 
org.tensorflow.op.Ops; import org.tensorflow.types.TBool; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.types.family.TType; /** @@ -33,6 +32,7 @@ public class DatasetOptional { /** * Gets the optional variant for this Dataset + * * @return the optional variant for this Dataset */ public Operand getOptionalVariant() { @@ -45,13 +45,17 @@ public Operand getOptionalVariant() { /** * Creates a DatasetOptional dataset + * * @param tf the TensorFlow Ops * @param optionalVariant the tensor that represents the dataset. * @param outputTypes a list of output types produced by this data set. * @param outputShapes a list of output shapes produced by this data set. */ public DatasetOptional( - Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { + Ops tf, + Operand optionalVariant, + List> outputTypes, + List outputShapes) { this.tf = tf; this.optionalVariant = optionalVariant; this.outputTypes = outputTypes; @@ -70,17 +74,19 @@ protected DatasetOptional(DatasetOptional other) { this.outputShapes = other.outputShapes; } - - - /** Gets the indicator of whether this optional has a value. + /** + * Gets the indicator of whether this optional has a value. + * * @return the indicator of whether this optional has a value. */ public Operand hasValue() { return tf.data.optionalHasValue(optionalVariant).hasValue(); } - /** Returns the value of the dataset element represented by this optional, if it exists. - * @return the value of the dataset element represented by this optional, if it exists. + /** + * Returns the value of the dataset element represented by this optional, if it exists. + * + * @return the value of the dataset element represented by this optional, if it exists. */ public List> getValue() { List> components = new ArrayList<>(); @@ -94,6 +100,7 @@ public List> getValue() { /** * Creates a DatasetOptional from components. 
+ * * @param tf the TensorFlow Ops * @param components the components that constitute the DatasetOptional * @param outputTypes a list of output types produced by this data set. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 508fb69fd55..f8be105d357 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; @@ -21,8 +23,6 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates tensors with a constant value. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index 34a77520406..ea73f764a38 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; @@ -21,8 +23,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates the identity matrix. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index 6e818d30bd7..ee7e483dd69 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; @@ -21,8 +23,6 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates tensors initialized to 1. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index 519d0cd042e..240d915f97f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.framework.utils.ShapeUtils; @@ -23,8 +25,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates an orthogonal matrix. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 9a52a641416..fd8aa3a6766 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -14,13 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates tensors with a normal distribution. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 7288024f5b8..45ef6c4491d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.random.RandomUniformInt; @@ -21,8 +23,6 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates tensors with a uniform distribution. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index 8069d5d9c7d..c5b23beef88 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -14,13 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer that generates a truncated normal distribution. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index a04e4a9a378..3ae493a8432 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; @@ -21,8 +23,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Initializer capable of adapting its scale to the shape of weights tensors. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 0c7c6abf8af..690396f2c28 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the cross-entropy loss between true labels and predicted labels. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 7d65353b004..9b3ed8eb19d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the crossentropy loss between the labels and 
predictions. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 05c5b47e329..9a443247996 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the hinge loss between labels and predictions. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 13b4eb8225a..dff77bfc75b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the crossentropy loss between labels and predictions. 
* @@ -83,8 +83,6 @@ public class SparseCategoricalCrossentropy extends AbstractLoss { * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss * name, a AbstractLoss Reduction of {@link AbstractLoss#REDUCTION_DEFAULT}, and fromLogits={@link * #FROM_LOGITS_DEFAULT}. - * - */ public SparseCategoricalCrossentropy() { this(null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index c804b463984..2959e541892 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.AbstractLoss; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the squared hinge loss between labels and predictions. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 69cb2ee0dfe..0ba94798c19 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -14,6 +14,15 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; @@ -27,16 +36,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index b8ec681cbfc..14f45020739 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.metrics.impl.LossMetric; @@ -23,8 +25,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Metric that calculates how often predictions equals labels. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index a03677efd43..c27bf1b2acf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Metric that calculates how often predictions matches binary labels. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index 0cd90325e32..70dfebc508d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; @@ -22,8 +24,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Metric that calculates how often predictions matches one-hot labels. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 4a32981aeeb..fa7c1a1a626 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; @@ -21,8 +23,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 04f4deb81cf..00ae3727249 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Collections; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.MetricsHelper; @@ -24,12 +29,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.Collections; -import java.util.List; - -import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the mean Intersection-Over-Union metric. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index 8d92b97ec5f..915d281e44b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -21,10 +24,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the mean relative error by normalizing with the given values. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index 583d9b2dde7..be09e7dd3f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.losses.impl.LossTuple; @@ -26,11 +30,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Metric that computes the element-wise (weighted) mean of the given tensors. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index f81b32e8d76..f978c0e20da 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -14,6 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; @@ -25,14 +32,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the precision of the predictions with respect to the labels. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 0bb49378f5b..a5285ff6b2d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes best precision where recall is >= specified value. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 2780add994f..6cb87f5be9e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -14,6 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; @@ -25,14 +32,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import 
java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the recall of the predictions with respect to the labels. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index e54def48fce..2386087e8a2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; import org.tensorflow.op.Ops; @@ -21,9 +24,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes best recall where precision is >= specified value. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 0d140eb96b3..8b0b06e788d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -15,6 +15,9 @@ */ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -22,10 +25,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes root mean squared error metric between {@code labels} and {@code predictions} . 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 23a529ae1bb..3892af920e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes best sensitivity where sensitivity is >= specified value. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 1d017ddf8fb..10d33c31508 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Collections; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; @@ -24,10 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.Collections; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Calculates how often predictions matches integer labels. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 95d46c8fd06..aa8eeb062b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes best specificity where sensitivity is >= specified value. 
{@code Sensitivity} * measures the proportion of actual positives that are correctly identified as such {@code (tp / diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index b6e50c3295a..b630be5bcc2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes the poisson loss metric between labels and predictions. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index b031d80d0ef..4463e1f8213 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.Metric; @@ -24,12 +29,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Abstract base class that calculates the value of the given confusion matrix condition based on * labels and predictions. 
@@ -190,6 +189,7 @@ public float[] getThresholds() { /** * Gets the accumulatorName + * * @return the accumulatorName */ public String getAccumulatorName() { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index ec103197709..d9f4bb60cba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; @@ -21,10 +24,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 51b8836ec83..7d265ef7651 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -14,6 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -38,17 +48,6 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; - -import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * These are helper methods for Metrics and will be module private when Java modularity is applied * to TensorFlow Java. These methods should not be used outside of the metrics packages. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index e47ea4ea8e8..6779b6b1f5a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -1,5 +1,12 @@ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.Metric; @@ -10,14 +17,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Abstract base class for computing sensitivity and specificity. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 0553b1edac7..dd77a1be4aa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -14,14 +14,14 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.SparseOps; import org.tensorflow.op.sparse.DenseToDenseSetOperation; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Implementation of set operations */ public class SetsOps { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java index 7c3fda07ea9..b8698ab197d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -14,12 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; -import org.tensorflow.Operand; -import org.tensorflow.types.family.TNumber; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; /** * A class that represents a Symbolic shape. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 18b11700380..2df90a841ee 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -23,12 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Weight broadcasting operations. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index f5a8c956072..87db69f2a77 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -14,13 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.regularizers; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A regularizer that applies both L1 and L2 regularization penalties. * diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index ef4644df18e..df1cfb9bd05 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -46,9 +46,14 @@ public void testCallFloat() { public void testCallDouble() { double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; double[] expected = { - 1.0507009873554805, -1.520166468595695, 3.1521029620664414, - -1.7258986281898947, -1.1113307378125628, 2.101401974710961, - -1.670568728767112, 4.202803949421922, + 1.0507009873554805, + -1.520166468595695, + 3.1521029620664414, + -1.7258986281898947, + -1.1113307378125628, + 2.101401974710961, + -1.670568728767112, + 4.202803949421922, }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java index 3988ec55bb3..696f96a367e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -30,10 +30,14 @@ public class TanhTest { public void testCallFloat() { float[] input = {1, -2, 3, -4, -5, 6, -7, 8}; float[] expected = { - 0.76159416F, -0.96402758F, - 0.99505475F, -0.9993293F, - -0.9999092F, 0.99998771F, - -0.99999834F, 0.99999977F + 0.76159416F, + -0.96402758F, + 0.99505475F, + -0.9993293F, + -0.9999092F, + 0.99998771F, + -0.99999834F, + 0.99999977F }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -49,10 +53,14 @@ public void testCallFloat() { public void testCallDouble() { double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; double[] expected = { - 0.76159416, -0.96402758, - 0.99505475, -0.9993293, - -0.9999092, 0.99998771, - -0.99999834, 0.99999977 + 0.76159416, + -0.96402758, + 0.99505475, + -0.9993293, + -0.9999092, + 0.99998771, + -0.99999834, + 0.99999977 }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index 259d6a963b5..c4f8f0ee89e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.constraints; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import 
org.tensorflow.framework.utils.TestSession; @@ -7,9 +9,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class MaxNormTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8b4c4007096..0d127b35b01 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.constraints; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.ND; @@ -10,9 +12,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class MinMaxNormTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 9291e5f83ef..5907deae547 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + import 
org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -27,9 +30,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; - /** Test the Constant initializer */ public class ConstantTest { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index 4872ce7ad8e..0bb0498e0cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.initializers; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -27,9 +30,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; - /** Test the Ones initializer */ public class OnesTest { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index 0b662414e8f..d5afdfb0da4 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package 
org.tensorflow.framework.losses; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -21,8 +23,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class BinaryCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -43,9 +43,7 @@ public void testAllCorrectUnweighted() { testSession.evaluate(expected, loss); // Test with logits. float[] logitsArray = { - 100.0f, -100.0f, -100.0f, - -100.0f, 100.0f, -100.0f, - -100.0f, -100.0f, 100.0f + 100.0f, -100.0f, -100.0f, -100.0f, 100.0f, -100.0f, -100.0f, -100.0f, 100.0f }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); @@ -101,10 +99,7 @@ public void testUnweighted() { // Test with logits. float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; - float[] logitsArray = { - 100.0f, -100.0f, 100.0f, - 100.0f, 100.0f, -100.0f - }; + float[] logitsArray = {100.0f, -100.0f, 100.0f, 100.0f, 100.0f, -100.0f}; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); @@ -135,10 +130,7 @@ public void testScalarWeighted() { // Test with logits. float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; - float[] logitsArray = { - 100.0f, -100.0f, 100.0f, - 100.0f, 100.0f, -100.0f - }; + float[] logitsArray = {100.0f, -100.0f, 100.0f, 100.0f, 100.0f, -100.0f}; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); @@ -170,10 +162,7 @@ public void testSampleWeighted() { // Test with logits. 
float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; - float[] logitsArray = { - 100.0f, -100.0f, 100.0f, - 100.0f, 100.0f, -100.0f - }; + float[] logitsArray = {100.0f, -100.0f, 100.0f, 100.0f, 100.0f, -100.0f}; float[] sampleWeightArray1 = {4f, 3f}; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); Operand logits = @@ -195,10 +184,7 @@ public void testNoReduction() { // Test with logits. float[] trueArray = {1f, 0f, 1f, 0f, 1f, 1f}; - float[] logitsArray = { - 100.0f, -100.0f, 100.0f, - 100.0f, 100.0f, -100.0f - }; + float[] logitsArray = {100.0f, -100.0f, 100.0f, 100.0f, 100.0f, -100.0f}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 3f6453b756a..25f5e5a54f1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -23,8 +25,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class CategoricalCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -90,11 +90,7 @@ public void testInvalidPredictionsRange() { 0L, 1L, 0L, 0L, 0L, 1L }; - float[] predArray = { - -1.F, 0.F, 0.F, - 0.F, 
1.F, 0.F, - 0.F, 0.F, 1.F - }; + float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index d5fe846c82e..9ad9f35491c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -21,8 +23,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class HingeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index 113b89b82ff..d3fdcff03b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -23,8 +25,6 @@ 
import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class SparseCategoricalCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 979e778e4c3..533e1179f7d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.losses; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -21,8 +23,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class SquaredHingeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d957cfb2508..17188499ee7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,5 +1,9 @@ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; 
import org.junit.jupiter.api.BeforeAll; @@ -29,11 +33,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - /** Test cases for GradientDescent Optimizer */ public class GradientDescentTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index a4b98c002cb..00da1f7e789 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.regularizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -7,8 +9,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertEquals; - class L1L2Test extends CommonTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index f7d540fb8e1..9a5efe2437e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.regularizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ 
-7,8 +9,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertEquals; - class L1Test extends CommonTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index 4579ccaf551..6153c36c38c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.regularizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -7,8 +9,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertEquals; - class L2Test extends CommonTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};