diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java index 5fe51121b13..dfabc0be8c7 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -55,29 +55,7 @@ * * @param the type of values to be mapped */ -public interface NdArray { - - /** - * @return the shape of this N-dimensional array - */ - Shape shape(); - - /** - * @return the rank of this N-dimensional array - */ - default int rank() { - return shape().numDimensions(); - } - - /** - * Computes and returns the total size of this N-dimensional array, in number of values. - * - *

For example, given a 3x3x2 matrix, the return value will be 18. - * @return total size of this nd array - */ - default long size() { - return shape().size(); - } +public interface NdArray extends NdArrayBase { /** * Returns a sequence of all elements at a given dimension. diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrayBase.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrayBase.java new file mode 100644 index 00000000000..75112ce2470 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrayBase.java @@ -0,0 +1,47 @@ +/* + Copyright 2019 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + +public interface NdArrayBase { + + /** + * @return the shape of this N-dimensional array + */ + Shape shape(); + + /** + * @return the rank of this N-dimensional array + */ + default int rank() { + return shape().numDimensions(); + } + + /** + * Computes and returns the total size of this N-dimensional array, in number of values. + * + *

For example, given a 3x3x2 matrix, the return value will be 18. + * @return total size of this nd array + */ + default long size() { + return shape().size(); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index 21b33402e98..8ad55cae7ed 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -65,7 +65,7 @@ public static ByteNdArray vectorOf(byte... values) { if (values == null) { throw new IllegalArgumentException("Values cannot be null"); } - return wrap(DataBuffers.of(values, false, false), Shape.of(values.length)); + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); } /** @@ -81,19 +81,19 @@ public static ByteNdArray ofBytes(Shape shape) { if (shape == null) { throw new IllegalArgumentException("Shape cannot be null"); } - return wrap(DataBuffers.ofBytes(shape.size()), shape); + return wrap(shape, DataBuffers.ofBytes(shape.size())); } /** * Wraps a buffer in a byte N-dimensional array of a given shape. * - * @param buffer buffer to wrap * @param shape shape of the array + * @param buffer buffer to wrap * @return new byte N-dimensional array * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger * in the buffer size */ - public static ByteNdArray wrap(ByteDataBuffer buffer, Shape shape) { + public static ByteNdArray wrap(Shape shape, ByteDataBuffer buffer) { return ByteDenseNdArray.create(buffer, shape); } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h index e41dc2dd9df..6028c4ea71d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h @@ -97,9 +97,6 @@ class Type { static Type IterableOf(const Type& type) { return Interface("Iterable").add_parameter(type); } - static Type DataTypeOf(const Type& type) { - return Class("DataType", "org.tensorflow").add_parameter(type); - } static Type ForDataType(DataType data_type) { switch (data_type) { case DataType::DT_BOOL: diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc index 843f3bdb247..2bec80a8457 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc @@ -177,7 +177,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class, if (attr.type().kind() == Type::GENERIC && default_types.find(attr.type().name()) != default_types.end()) { factory_statement << default_types.at(attr.type().name()).name() - << ".DTYPE"; + << ".class"; } else { AddArgument(attr.var(), attr.description(), &factory, &factory_doc); factory_statement << attr.var().name(); diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc index c9e0525edb7..8f676deeead 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc @@ -81,13 +81,19 @@ class TypeResolver { std::pair MakeTypePair(const Type& type) { return std::make_pair(type, type); } - Type NextGeneric() { + Type NextGeneric(const OpDef_AttrDef& attr_def) { char generic_letter = next_generic_letter_++; if (next_generic_letter_ > 'Z') { next_generic_letter_ = 'A'; } - return Type::Generic(string(1, generic_letter)) - .add_supertype(Type::Class("TType", "org.tensorflow.types.family")); + return Type::Generic(string(1, generic_letter)); + } + Type TypeFamilyOf(const OpDef_AttrDef& attr_def) { + // TODO(karllessard) support more type families + if (IsRealNumbers(attr_def.allowed_values())) { + return Type::Interface("TNumber", "org.tensorflow.types.family"); + } + return Type::Interface("TType", "org.tensorflow.types.family"); } }; @@ -152,15 +158,12 @@ std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray")); } else if (attr_type == "tensor") { - types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard())); + types = MakeTypePair(Type::Class("TType", "org.tensorflow.types.family")); } else if (attr_type == "type") { - Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); - if (IsRealNumbers(attr_def.allowed_values())) { - type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family")); - } - types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow")); + Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def); + type.add_supertype(TypeFamilyOf(attr_def)); + types = MakeTypePair(type, Type::Class("Class")); } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type @@ -306,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, bool iterable = false; std::pair types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC - ? Type::DataTypeOf(types.first) + ? Type::ClassOf(types.first) : types.first; if (iterable) { var_type = Type::ListOf(var_type); diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc index 8598b1d945d..37315f0dff3 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc @@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { SourceWriter& SourceWriter::AppendType(const Type& type) { if (type.wildcard()) { Append("?"); + WriteTypeBounds(type.supertypes()); } else { Append(type.name()); if (!type.parameters().empty()) { @@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics( Append(", "); } Append(pt->name()); - if (!pt->supertypes().empty()) { - Append(" extends ").AppendType(pt->supertypes().front()); - } + WriteTypeBounds(pt->supertypes()); first = false; } return Append(">"); } +SourceWriter& SourceWriter::WriteTypeBounds( + const std::list& bounds) { + bool first = true; + for (const Type& bound : bounds) { + if (first) { + Append(" extends "); + first = false; + } else { + Append(" & "); + } + AppendType(bound); + } + return *this; +} + SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace( int modifiers) { GenericNamespace* generic_namespace; diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h index 097887083e7..26b97f7a9c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h @@ -213,6 +213,7 @@ class SourceWriter { SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); + SourceWriter& WriteTypeBounds(const std::list& bounds); GenericNamespace* PushGenericNamespace(int modifiers); void PopGenericNamespace(); }; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java index cccc4ac8dcb..411d8e7969a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java @@ -18,12 +18,12 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.experimental.DataServiceDataset; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * An API for building {@code data.experimental} operations as {@link Op Op}s @@ -54,7 +54,7 @@ public final class DataExperimentalOps { public DataServiceDataset dataServiceDataset(Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, - List> outputTypes, List outputShapes, + List> outputTypes, List outputShapes, DataServiceDataset.Options... options) { return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index 273025ef6bd..d01c9ee6d01 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.AnonymousIterator; @@ -49,6 +48,7 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * An API for building {@code data} operations as {@link Op Op}s @@ -72,7 +72,7 @@ public final class DataOps { * @param outputShapes * @return a new instance of AnonymousIterator */ - public AnonymousIterator anonymousIterator(List> outputTypes, + public AnonymousIterator anonymousIterator(List> outputTypes, List outputShapes) { return AnonymousIterator.create(scope, outputTypes, outputShapes); } @@ -90,8 +90,8 @@ public AnonymousIterator anonymousIterator(List> outputTypes, * @return a new instance of BatchDataset */ public BatchDataset batchDataset(Operand inputDataset, Operand batchSize, - Operand dropRemainder, List> outputTypes, List outputShapes, - BatchDataset.Options... options) { + Operand dropRemainder, List> outputTypes, + List outputShapes, BatchDataset.Options... options) { return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options); } @@ -126,7 +126,7 @@ public CSVDataset cSVDataset(Operand filenames, Operand compre * @return a new instance of ConcatenateDataset */ public ConcatenateDataset concatenateDataset(Operand inputDataset, Operand anotherDataset, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes); } @@ -161,8 +161,8 @@ public DeserializeIterator deserializeIterator(Operand resourceHandle, Operan * @param outputShapes * @return a new instance of Iterator */ - public Iterator iterator(String sharedName, String container, List> outputTypes, - List outputShapes) { + public Iterator iterator(String sharedName, String container, + List> outputTypes, List outputShapes) { return Iterator.create(scope, sharedName, container, outputTypes, outputShapes); } @@ -174,8 +174,8 @@ public Iterator iterator(String sharedName, String container, List> * @param outputShapes * @return a new instance of IteratorGetNext */ - public IteratorGetNext iteratorGetNext(Operand iterator, List> outputTypes, - List outputShapes) { + public IteratorGetNext iteratorGetNext(Operand iterator, + List> outputTypes, List outputShapes) { return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes); } @@ -188,7 +188,7 @@ public IteratorGetNext iteratorGetNext(Operand iterator, List> ou * @return a new instance of IteratorGetNextAsOptional */ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand iterator, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes); } @@ -205,8 +205,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand iterator, * @param outputShapes * @return a new instance of IteratorGetNextSync */ - public IteratorGetNextSync iteratorGetNextSync(Operand iterator, List> outputTypes, - List outputShapes) { + public IteratorGetNextSync iteratorGetNextSync(Operand iterator, + List> outputTypes, List outputShapes) { return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes); } @@ -252,8 +252,8 @@ public OptionalFromValue optionalFromValue(Iterable> components) { * @param outputShapes * @return a new instance of OptionalGetValue */ - public OptionalGetValue optionalGetValue(Operand optional, List> outputTypes, - List outputShapes) { + public OptionalGetValue optionalGetValue(Operand optional, + List> outputTypes, List outputShapes) { return OptionalGetValue.create(scope, optional, outputTypes, outputShapes); } @@ -287,7 +287,7 @@ public OptionalNone optionalNone() { * @return a new instance of RangeDataset */ public RangeDataset rangeDataset(Operand start, Operand stop, - Operand step, List> outputTypes, List outputShapes) { + Operand step, List> outputTypes, List outputShapes) { return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes); } @@ -302,7 +302,7 @@ public RangeDataset rangeDataset(Operand start, Operand stop, * @return a new instance of RepeatDataset */ public RepeatDataset repeatDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -329,7 +329,7 @@ public SerializeIterator serializeIterator(Operand resourceHandle, * @return a new instance of SkipDataset */ public SkipDataset skipDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -345,7 +345,7 @@ public SkipDataset skipDataset(Operand inputDataset, Operand count, * @return a new instance of TakeDataset */ public TakeDataset takeDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -406,8 +406,8 @@ public TfRecordDataset tfRecordDataset(Operand filenames, * @param outputShapes * @return a new instance of ZipDataset */ - public ZipDataset zipDataset(Iterable> inputDatasets, List> outputTypes, - List outputShapes) { + public ZipDataset zipDataset(Iterable> inputDatasets, + List> outputTypes, List outputShapes) { return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java deleted file mode 100644 index f12d18f925b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.Operand; -import org.tensorflow.op.debugging.CheckNumerics; -import org.tensorflow.types.family.TNumber; - -/** - * An API for building {@code debugging} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class DebuggingOps { - private final Scope scope; - - DebuggingOps(Scope scope) { - this.scope = scope; - } - - /** - * Checks a tensor for NaN and Inf values. - *

- * When run, reports an `InvalidArgument` error if `tensor` has any values - * that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. - * - * @param data type for {@code output()} output - * @param tensor - * @param message Prefix of the error message. - * @return a new instance of CheckNumerics - */ - public CheckNumerics checkNumerics(Operand tensor, String message) { - return CheckNumerics.create(scope, tensor, message); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java index 16d571a6428..824f3383f2b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.dtypes.AsString; import org.tensorflow.op.dtypes.Cast; @@ -70,7 +69,7 @@ public AsString asString(Operand input, AsString.Options... * @param options carries optional attributes values * @return a new instance of Cast */ - public Cast cast(Operand x, DataType DstT, + public Cast cast(Operand x, Class DstT, Cast.Options... options) { return Cast.create(scope, x, DstT, options); } @@ -99,7 +98,7 @@ public Cast cast(Operand x, DataType * @return a new instance of Complex */ public Complex complex(Operand real, Operand imag, - DataType Tout) { + Class Tout) { return Complex.create(scope, real, imag, Tout); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java index eea2fc4b8f1..b0fa7751b8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.image.AdjustContrast; import org.tensorflow.op.image.AdjustHue; @@ -269,7 +268,7 @@ public CropAndResizeGradBoxes cropAndResizeGradBoxes(Operand */ public CropAndResizeGradImage cropAndResizeGradImage( Operand grads, Operand boxes, Operand boxInd, - Operand imageSize, DataType T, CropAndResizeGradImage.Options... options) { + Operand imageSize, Class T, CropAndResizeGradImage.Options... options) { return CropAndResizeGradImage.create(scope, grads, boxes, boxInd, imageSize, T, options); } @@ -460,7 +459,7 @@ public DecodePng decodePng(Operand contents, DecodePng.Options. * @param options carries optional attributes values * @return a new instance of DecodePng */ - public DecodePng decodePng(Operand contents, DataType dtype, + public DecodePng decodePng(Operand contents, Class dtype, DecodePng.Options... options) { return DecodePng.create(scope, contents, dtype, options); } @@ -622,7 +621,7 @@ public ExtractJpegShape extractJpegShape(Operand contents) { * @return a new instance of ExtractJpegShape */ public ExtractJpegShape extractJpegShape(Operand contents, - DataType outputType) { + Class outputType) { return ExtractJpegShape.create(scope, contents, outputType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java index adc656dc5af..78aa5897a97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.io.DecodeBase64; @@ -168,7 +167,7 @@ public DecodeJsonExample decodeJsonExample(Operand jsonExamples) { * @return a new instance of DecodePaddedRaw */ public DecodePaddedRaw decodePaddedRaw(Operand inputBytes, - Operand fixedLength, DataType outType, DecodePaddedRaw.Options... options) { + Operand fixedLength, Class outType, DecodePaddedRaw.Options... options) { return DecodePaddedRaw.create(scope, inputBytes, fixedLength, outType, options); } @@ -181,7 +180,7 @@ public DecodePaddedRaw decodePaddedRaw(Operand i * @param options carries optional attributes values * @return a new instance of DecodeRaw */ - public DecodeRaw decodeRaw(Operand bytes, DataType outType, + public DecodeRaw decodeRaw(Operand bytes, Class outType, DecodeRaw.Options... options) { return DecodeRaw.create(scope, bytes, outType, options); } @@ -238,7 +237,7 @@ public DecodeRaw decodeRaw(Operand bytes, DataType * @return a new instance of DeserializeManySparse */ public DeserializeManySparse deserializeManySparse( - Operand serializedSparse, DataType dtype) { + Operand serializedSparse, Class dtype) { return DeserializeManySparse.create(scope, serializedSparse, dtype); } @@ -267,7 +266,8 @@ public EncodeBase64 encodeBase64(Operand input, EncodeBase64.Options... * @param options carries optional attributes values * @return a new instance of FifoQueue */ - public FifoQueue fifoQueue(List> componentTypes, FifoQueue.Options... options) { + public FifoQueue fifoQueue(List> componentTypes, + FifoQueue.Options... options) { return FifoQueue.create(scope, componentTypes, options); } @@ -331,7 +331,7 @@ public MatchingFiles matchingFiles(Operand pattern) { * @param options carries optional attributes values * @return a new instance of PaddingFifoQueue */ - public PaddingFifoQueue paddingFifoQueue(List> componentTypes, + public PaddingFifoQueue paddingFifoQueue(List> componentTypes, PaddingFifoQueue.Options... options) { return PaddingFifoQueue.create(scope, componentTypes, options); } @@ -393,9 +393,9 @@ public PaddingFifoQueue paddingFifoQueue(List> componentTypes, */ public ParseExample parseExample(Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, - Iterable> denseDefaults, Long numSparse, List> sparseTypes, - List> raggedValueTypes, List> raggedSplitTypes, - List denseShapes) { + Iterable> denseDefaults, Long numSparse, List> sparseTypes, + List> raggedValueTypes, + List> raggedSplitTypes, List denseShapes) { return ParseExample.create(scope, serialized, names, sparseKeys, denseKeys, raggedKeys, denseDefaults, numSparse, sparseTypes, raggedValueTypes, raggedSplitTypes, denseShapes); } @@ -451,10 +451,13 @@ public ParseSequenceExample parseSequenceExample(Operand serialized, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, - Iterable> contextDenseDefaults, List> contextSparseTypes, - List> contextRaggedValueTypes, List> contextRaggedSplitTypes, - List> featureListDenseTypes, List> featureListSparseTypes, - List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, + Iterable> contextDenseDefaults, List> contextSparseTypes, + List> contextRaggedValueTypes, + List> contextRaggedSplitTypes, + List> featureListDenseTypes, + List> featureListSparseTypes, + List> featureListRaggedValueTypes, + List> featureListRaggedSplitTypes, ParseSequenceExample.Options... options) { return ParseSequenceExample.create(scope, serialized, debugName, contextSparseKeys, contextDenseKeys, contextRaggedKeys, featureListSparseKeys, featureListDenseKeys, featureListRaggedKeys, featureListDenseMissingAssumedEmpty, contextDenseDefaults, contextSparseTypes, contextRaggedValueTypes, contextRaggedSplitTypes, featureListDenseTypes, featureListSparseTypes, featureListRaggedValueTypes, featureListRaggedSplitTypes, options); } @@ -496,7 +499,7 @@ public ParseSequenceExample parseSequenceExample(Operand serialized, */ public ParseSingleExample parseSingleExample(Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, - List denseKeys, List> sparseTypes, List denseShapes) { + List denseKeys, List> sparseTypes, List denseShapes) { return ParseSingleExample.create(scope, serialized, denseDefaults, numSparse, sparseKeys, denseKeys, sparseTypes, denseShapes); } @@ -550,8 +553,9 @@ public ParseSingleSequenceExample parseSingleSequenceExample(Operand se Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, - Operand debugName, List> contextSparseTypes, - List> featureListDenseTypes, List> featureListSparseTypes, + Operand debugName, List> contextSparseTypes, + List> featureListDenseTypes, + List> featureListSparseTypes, ParseSingleSequenceExample.Options... options) { return ParseSingleSequenceExample.create(scope, serialized, featureListDenseMissingAssumedEmpty, contextSparseKeys, contextDenseKeys, featureListSparseKeys, featureListDenseKeys, contextDenseDefaults, debugName, contextSparseTypes, featureListDenseTypes, featureListSparseTypes, options); } @@ -566,7 +570,7 @@ public ParseSingleSequenceExample parseSingleSequenceExample(Operand se * @return a new instance of ParseTensor */ public ParseTensor parseTensor(Operand serialized, - DataType outType) { + Class outType) { return ParseTensor.create(scope, serialized, outType); } @@ -587,8 +591,8 @@ public ParseTensor parseTensor(Operand serialized, * @param options carries optional attributes values * @return a new instance of PriorityQueue */ - public PriorityQueue priorityQueue(List> componentTypes, List shapes, - PriorityQueue.Options... options) { + public PriorityQueue priorityQueue(List> componentTypes, + List shapes, PriorityQueue.Options... options) { return PriorityQueue.create(scope, componentTypes, shapes, options); } @@ -624,7 +628,7 @@ public QueueClose queueClose(Operand handle, QueueClose.Options... options) { * @param options carries optional attributes values * @return a new instance of QueueDequeue */ - public QueueDequeue queueDequeue(Operand handle, List> componentTypes, + public QueueDequeue queueDequeue(Operand handle, List> componentTypes, QueueDequeue.Options... options) { return QueueDequeue.create(scope, handle, componentTypes, options); } @@ -653,7 +657,7 @@ public QueueDequeue queueDequeue(Operand handle, List> componentT * @return a new instance of QueueDequeueMany */ public QueueDequeueMany queueDequeueMany(Operand handle, Operand n, - List> componentTypes, QueueDequeueMany.Options... options) { + List> componentTypes, QueueDequeueMany.Options... options) { return QueueDequeueMany.create(scope, handle, n, componentTypes, options); } @@ -685,7 +689,7 @@ public QueueDequeueMany queueDequeueMany(Operand handle, Operand n, * @return a new instance of QueueDequeueUpTo */ public QueueDequeueUpTo queueDequeueUpTo(Operand handle, Operand n, - List> componentTypes, QueueDequeueUpTo.Options... options) { + List> componentTypes, QueueDequeueUpTo.Options... options) { return QueueDequeueUpTo.create(scope, handle, n, componentTypes, options); } @@ -762,7 +766,7 @@ public QueueSize queueSize(Operand handle) { * @param options carries optional attributes values * @return a new instance of RandomShuffleQueue */ - public RandomShuffleQueue randomShuffleQueue(List> componentTypes, + public RandomShuffleQueue randomShuffleQueue(List> componentTypes, RandomShuffleQueue.Options... options) { return RandomShuffleQueue.create(scope, componentTypes, options); } @@ -914,7 +918,7 @@ public SerializeManySparse serializeManySparse( */ public SerializeManySparse serializeManySparse( Operand sparseIndices, Operand sparseValues, Operand sparseShape, - DataType outType) { + Class outType) { return SerializeManySparse.create(scope, sparseIndices, sparseValues, sparseShape, outType); } @@ -945,7 +949,7 @@ public SerializeSparse serializeSparse(Operand SerializeSparse serializeSparse( Operand sparseIndices, Operand sparseValues, Operand sparseShape, - DataType outType) { + Class outType) { return SerializeSparse.create(scope, sparseIndices, sparseValues, sparseShape, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java index b2242f1068e..8163e0a50b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.linalg.BandPart; import org.tensorflow.op.linalg.BatchCholesky; @@ -396,7 +395,7 @@ public Det det(Operand input) { * @param options carries optional attributes values * @return a new instance of Eig */ - public Eig eig(Operand input, DataType Tout, + public Eig eig(Operand input, Class Tout, Eig.Options... options) { return Eig.create(scope, input, Tout, options); } @@ -682,7 +681,7 @@ public Lu lu(Operand input) { * @return a new instance of Lu */ public Lu lu(Operand input, - DataType outputIdxType) { + Class outputIdxType) { return Lu.create(scope, input, outputIdxType); } @@ -1373,7 +1372,7 @@ public Qr qr(Operand input, Qr.Options... options) { */ public QuantizedMatMul quantizedMatMul( Operand a, Operand b, Operand minA, Operand maxA, - Operand minB, Operand maxB, DataType Toutput, DataType Tactivation, + Operand minB, Operand maxB, Class Toutput, Class Tactivation, QuantizedMatMul.Options... options) { return QuantizedMatMul.create(scope, a, b, minA, maxA, minB, maxB, Toutput, Tactivation, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgSparseOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgSparseOps.java deleted file mode 100644 index 7f8777c883a..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgSparseOps.java +++ /dev/null @@ -1,463 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.DataType; -import org.tensorflow.Operand; -import org.tensorflow.op.linalg.sparse.CSRSparseMatrixToSparseTensor; -import org.tensorflow.op.linalg.sparse.DenseToCSRSparseMatrix; -import org.tensorflow.op.linalg.sparse.SparseMatrixAdd; -import org.tensorflow.op.linalg.sparse.SparseMatrixMatMul; -import org.tensorflow.op.linalg.sparse.SparseMatrixMul; -import org.tensorflow.op.linalg.sparse.SparseMatrixNNZ; -import org.tensorflow.op.linalg.sparse.SparseMatrixOrderingAMD; -import org.tensorflow.op.linalg.sparse.SparseMatrixSoftmax; -import org.tensorflow.op.linalg.sparse.SparseMatrixSoftmaxGrad; -import org.tensorflow.op.linalg.sparse.SparseMatrixSparseCholesky; -import org.tensorflow.op.linalg.sparse.SparseMatrixSparseMatMul; -import org.tensorflow.op.linalg.sparse.SparseMatrixTranspose; -import org.tensorflow.op.linalg.sparse.SparseMatrixZeros; -import org.tensorflow.op.linalg.sparse.SparseTensorToCSRSparseMatrix; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; - -/** - * An API for building {@code linalg.sparse} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class LinalgSparseOps { - private final Scope scope; - - LinalgSparseOps(Scope scope) { - this.scope = scope; - } - - /** - * Converts a (possibly batched) CSRSparesMatrix to a SparseTensor. - * - * @param data type for {@code values()} output - * @param sparseMatrix A (possibly batched) CSRSparseMatrix. - * @param type - * @return a new instance of CSRSparseMatrixToSparseTensor - */ - public CSRSparseMatrixToSparseTensor cSRSparseMatrixToSparseTensor( - Operand sparseMatrix, DataType type) { - return CSRSparseMatrixToSparseTensor.create(scope, sparseMatrix, type); - } - - /** - * Converts a dense tensor to a (possibly batched) CSRSparseMatrix. - * - * @param denseInput A Dense tensor. - * @param indices Indices of nonzero elements. - * @return a new instance of DenseToCSRSparseMatrix - */ - public DenseToCSRSparseMatrix denseToCSRSparseMatrix(Operand denseInput, - Operand indices) { - return DenseToCSRSparseMatrix.create(scope, denseInput, indices); - } - - /** - * Sparse addition of two CSR matrices, C = alpha * A + beta * B. - *

- * The gradients of SparseMatrixAdd outputs with respect to alpha and beta are not - * currently defined (TensorFlow will return zeros for these entries). - * - * @param a A CSRSparseMatrix. - * @param b A CSRSparseMatrix. - * @param alpha A constant scalar. - * @param beta A constant scalar. - * @return a new instance of SparseMatrixAdd - */ - public SparseMatrixAdd sparseMatrixAdd(Operand a, Operand b, - Operand alpha, Operand beta) { - return SparseMatrixAdd.create(scope, a, b, alpha, beta); - } - - /** - * Matrix-multiplies a sparse matrix with a dense matrix. - *

- * Returns a dense matrix. - * For inputs A and B, where A is CSR and B is dense; this op returns a dense C; - *

- * If transpose_output is false, returns: - *

{@code
-   *    C = A . B
-   *  }
- * If transpose_output is `true`, returns: - *
{@code
-   *    C = transpose(A . B) = transpose(B) . transpose(A)
-   *  }
- * where the transposition is performed along the two innermost (matrix) - * dimensions. - *

- * If conjugate_output is `true`, returns: - *

{@code
-   *    C = conjugate(A . B) = conjugate(A) . conjugate(B)
-   *  }
- * If both conjugate_output and transpose_output are `true`, returns: - *
{@code
-   *    C = conjugate(transpose(A . B)) = conjugate(transpose(B)) .
-   *                                      conjugate(transpose(A))
-   *  }
- * - * @param data type for {@code output()} output - * @param a A CSRSparseMatrix. - * @param b A dense tensor. - * @param options carries optional attributes values - * @return a new instance of SparseMatrixMatMul - */ - public SparseMatrixMatMul sparseMatrixMatMul(Operand a, Operand b, - SparseMatrixMatMul.Options... options) { - return SparseMatrixMatMul.create(scope, a, b, options); - } - - /** - * Element-wise multiplication of a sparse matrix with a dense tensor. - *

- * Returns a sparse matrix. - *

- * The dense tensor `b` may be either a scalar; otherwise `a` must be a rank-3 - * `SparseMatrix`; in this case `b` must be shaped `[batch_size, 1, 1]` and the - * multiply operation broadcasts. - *

- * NOTE even if `b` is zero, the sparsity structure of the output does not - * change. - * - * @param a A CSRSparseMatrix. - * @param b A dense tensor. - * @return a new instance of SparseMatrixMul - */ - public SparseMatrixMul sparseMatrixMul(Operand a, Operand b) { - return SparseMatrixMul.create(scope, a, b); - } - - /** - * Returns the number of nonzeroes of `sparse_matrix`. - * - * @param sparseMatrix A CSRSparseMatrix. - * @return a new instance of SparseMatrixNNZ - */ - public SparseMatrixNNZ sparseMatrixNNZ(Operand sparseMatrix) { - return SparseMatrixNNZ.create(scope, sparseMatrix); - } - - /** - * Computes the Approximate Minimum Degree (AMD) ordering of `input`. - *

- * Computes the Approximate Minimum Degree (AMD) ordering for a sparse matrix. - *

- * The returned permutation may be used to permute the rows and columns of the - * given sparse matrix. This typically results in permuted sparse matrix's sparse - * Cholesky (or other decompositions) in having fewer zero fill-in compared to - * decomposition of the original matrix. - *

- * The input sparse matrix may have rank 2 or rank 3. The output Tensor, - * representing would then have rank 1 or 2 respectively, with the same batch - * shape as the input. - *

- * Each component of the input sparse matrix must represent a square symmetric - * matrix; only the lower triangular part of the matrix is read. The values of the - * sparse matrix does not affect the returned permutation, only the sparsity - * pattern of the sparse matrix is used. Hence, a single AMD ordering may be - * reused for the Cholesky decompositions of sparse matrices with the same sparsity - * pattern but with possibly different values. - *

- * Each batch component of the output permutation represents a permutation of `N` - * elements, where the input sparse matrix components each have `N` rows. That is, - * the component contains each of the integers `{0, .. N-1}` exactly once. The - * `i`th element represents the row index that the `i`th row maps to. - *

- * Usage example: - *

{@code
-   *      from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
-   *
-   *      a_indices = np.array([[0, 0], [1, 1], [2, 1], [2, 2], [3, 3]])
-   *      a_values = np.array([1.0, 2.0, 1.0, 3.0, 4.0], np.float32)
-   *      a_dense_shape = [4, 4]
-   *
-   *      with tf.Session() as sess:
-   *        # Define (COO format) SparseTensor over Numpy array.
-   *        a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)
-   *
-   *        # Convert SparseTensors to CSR SparseMatrix.
-   *        a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
-   *            a_st.indices, a_st.values, a_st.dense_shape)
-   *
-   *        # Obtain the AMD Ordering for the CSR SparseMatrix.
-   *        ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(sparse_matrix)
-   *
-   *        ordering_amd_value = sess.run(ordering_amd)
-   *  }
- * `ordering_amd_value` stores the AMD ordering: `[1 2 3 0]`. - *

- * input: A `CSRSparseMatrix`. - * - * @param input A `CSRSparseMatrix`. - * @return a new instance of SparseMatrixOrderingAMD - */ - public SparseMatrixOrderingAMD sparseMatrixOrderingAMD(Operand input) { - return SparseMatrixOrderingAMD.create(scope, input); - } - - /** - * Calculates the softmax of a CSRSparseMatrix. - *

- * Calculate the softmax of the innermost dimensions of a SparseMatrix. - *

- * Missing values are treated as `-inf` (i.e., logits of zero probability); and - * the output has the same sparsity structure as the input (though missing values - * in the output may now be treated as having probability zero). - * - * @param logits A CSRSparseMatrix. - * @param type - * @return a new instance of SparseMatrixSoftmax - */ - public SparseMatrixSoftmax sparseMatrixSoftmax(Operand logits, - DataType type) { - return SparseMatrixSoftmax.create(scope, logits, type); - } - - /** - * Calculates the gradient of the SparseMatrixSoftmax op. - * - * @param softmax A CSRSparseMatrix. - * @param gradSoftmax The gradient of `softmax`. - * @param type - * @return a new instance of SparseMatrixSoftmaxGrad - */ - public SparseMatrixSoftmaxGrad sparseMatrixSoftmaxGrad(Operand softmax, - Operand gradSoftmax, DataType type) { - return SparseMatrixSoftmaxGrad.create(scope, softmax, gradSoftmax, type); - } - - /** - * Computes the sparse Cholesky decomposition of `input`. - *

- * Computes the Sparse Cholesky decomposition of a sparse matrix, with the given - * fill-in reducing permutation. - *

- * The input sparse matrix and the fill-in reducing permutation `permutation` must - * have compatible shapes. If the sparse matrix has rank 3; with the batch - * dimension `B`, then the `permutation` must be of rank 2; with the same batch - * dimension `B`. There is no support for broadcasting. - *

- * Furthermore, each component vector of `permutation` must be of length `N`, - * containing each of the integers {0, 1, ..., N - 1} exactly once, where `N` is - * the number of rows of each component of the sparse matrix. - *

- * Each component of the input sparse matrix must represent a symmetric positive - * definite (SPD) matrix; although only the lower triangular part of the matrix is - * read. If any individual component is not SPD, then an InvalidArgument error is - * thrown. - *

- * The returned sparse matrix has the same dense shape as the input sparse matrix. - * For each component `A` of the input sparse matrix, the corresponding output - * sparse matrix represents `L`, the lower triangular Cholesky factor satisfying - * the following identity: - *

{@code
-   *    A = L * Lt
-   *  }
- * where Lt denotes the transpose of L (or its conjugate transpose, if `type` is - * `complex64` or `complex128`). - *

- * The `type` parameter denotes the type of the matrix elements. The supported - * types are: `float32`, `float64`, `complex64` and `complex128`. - *

- * Usage example: - *

{@code
-   *      from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
-   *
-   *      a_indices = np.array([[0, 0], [1, 1], [2, 1], [2, 2], [3, 3]])
-   *      a_values = np.array([1.0, 2.0, 1.0, 3.0, 4.0], np.float32)
-   *      a_dense_shape = [4, 4]
-   *
-   *      with tf.Session() as sess:
-   *        # Define (COO format) SparseTensor over Numpy array.
-   *        a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)
-   *
-   *        # Convert SparseTensors to CSR SparseMatrix.
-   *        a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
-   *            a_st.indices, a_st.values, a_st.dense_shape)
-   *
-   *        # Obtain the Sparse Cholesky factor using AMD Ordering for reducing zero
-   *        # fill-in (number of structural non-zeros in the sparse Cholesky factor).
-   *        ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(sparse_matrix)
-   *        cholesky_sparse_matrices = (
-   *            sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
-   *                sparse_matrix, ordering_amd, type=tf.float32))
-   *
-   *        # Convert the CSRSparseMatrix Cholesky factor to a dense Tensor
-   *        dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
-   *            cholesky_sparse_matrices, tf.float32)
-   *
-   *        # Evaluate the dense Tensor value.
-   *        dense_cholesky_value = sess.run(dense_cholesky)
-   *  }
- * `dense_cholesky_value` stores the dense Cholesky factor: - *
{@code
-   *      [[  1.  0.    0.    0.]
-   *       [  0.  1.41  0.    0.]
-   *       [  0.  0.70  1.58  0.]
-   *       [  0.  0.    0.    2.]]
-   *  }
- * input: A `CSRSparseMatrix`. - * permutation: A `Tensor`. - * type: The type of `input`. - * - * @param input A `CSRSparseMatrix`. - * @param permutation A fill-in reducing permutation matrix. - * @param type - * @return a new instance of SparseMatrixSparseCholesky - */ - public SparseMatrixSparseCholesky sparseMatrixSparseCholesky(Operand input, - Operand permutation, DataType type) { - return SparseMatrixSparseCholesky.create(scope, input, permutation, type); - } - - /** - * Sparse-matrix-multiplies two CSR matrices `a` and `b`. - *

- * Performs a matrix multiplication of a sparse matrix `a` with a sparse matrix - * `b`; returns a sparse matrix `a * b`, unless either `a` or `b` is transposed or - * adjointed. - *

- * Each matrix may be transposed or adjointed (conjugated and transposed) - * according to the Boolean parameters `transpose_a`, `adjoint_a`, `transpose_b` - * and `adjoint_b`. At most one of `transpose_a` or `adjoint_a` may be True. - * Similarly, at most one of `transpose_b` or `adjoint_b` may be True. - *

- * The inputs must have compatible shapes. That is, the inner dimension of `a` - * must be equal to the outer dimension of `b`. This requirement is adjusted - * according to whether either `a` or `b` is transposed or adjointed. - *

- * The `type` parameter denotes the type of the matrix elements. Both `a` and `b` - * must have the same type. The supported types are: `float32`, `float64`, - * `complex64` and `complex128`. - *

- * Both `a` and `b` must have the same rank. Broadcasting is not supported. If they - * have rank 3, each batch of 2D CSRSparseMatrices within `a` and `b` must have the - * same dense shape. - *

- * The sparse matrix product may have numeric (non-structural) zeros. - * TODO(anudhyan): Consider adding a boolean attribute to control whether to prune - * zeros. - *

- * Usage example: - *

{@code
-   *      from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
-   *
-   *      a_indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
-   *      a_values = np.array([1.0, 5.0, -1.0, -2.0], np.float32)
-   *      a_dense_shape = [4, 5]
-   *
-   *      b_indices = np.array([[0, 0], [3, 0], [3, 1]])
-   *      b_values = np.array([2.0, 7.0, 8.0], np.float32)
-   *      b_dense_shape = [5, 3]
-   *
-   *      with tf.Session() as sess:
-   *        # Define (COO format) Sparse Tensors over Numpy arrays
-   *        a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)
-   *        b_st = tf.SparseTensor(b_indices, b_values, b_dense_shape)
-   *
-   *        # Convert SparseTensors to CSR SparseMatrix
-   *        a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
-   *            a_st.indices, a_st.values, a_st.dense_shape)
-   *        b_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
-   *            b_st.indices, b_st.values, b_st.dense_shape)
-   *
-   *        # Compute the CSR SparseMatrix matrix multiplication
-   *        c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
-   *            a=a_sm, b=b_sm, type=tf.float32)
-   *
-   *        # Convert the CSR SparseMatrix product to a dense Tensor
-   *        c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
-   *            c_sm, tf.float32)
-   *        # Evaluate the dense Tensor value
-   *        c_sm_dense_value = sess.run(c_sm_dense)
-   *  }
- * `c_sm_dense_value` stores the dense matrix product: - *
{@code
-   *      [[  2.   0.   0.]
-   *       [  0.   0.   0.]
-   *       [ 35.  40.   0.]
-   *       [ -4.   0.   0.]]
-   *  }
- * a: A `CSRSparseMatrix`. - * b: A `CSRSparseMatrix` with the same type and rank as `a`. - * type: The type of both `a` and `b`. - * transpose_a: If True, `a` transposed before multiplication. - * transpose_b: If True, `b` transposed before multiplication. - * adjoint_a: If True, `a` adjointed before multiplication. - * adjoint_b: If True, `b` adjointed before multiplication. - * - * @param a A CSRSparseMatrix. - * @param b A CSRSparseMatrix. - * @param type - * @param options carries optional attributes values - * @return a new instance of SparseMatrixSparseMatMul - */ - public SparseMatrixSparseMatMul sparseMatrixSparseMatMul(Operand a, - Operand b, DataType type, SparseMatrixSparseMatMul.Options... options) { - return SparseMatrixSparseMatMul.create(scope, a, b, type, options); - } - - /** - * Transposes the inner (matrix) dimensions of a CSRSparseMatrix. - *

- * Transposes the inner (matrix) dimensions of a SparseMatrix and optionally - * conjugates its values. - * - * @param input A CSRSparseMatrix. - * @param type - * @param options carries optional attributes values - * @return a new instance of SparseMatrixTranspose - */ - public SparseMatrixTranspose sparseMatrixTranspose(Operand input, - DataType type, SparseMatrixTranspose.Options... options) { - return SparseMatrixTranspose.create(scope, input, type, options); - } - - /** - * Creates an all-zeros CSRSparseMatrix with shape `dense_shape`. - * - * @param denseShape The desired matrix shape. - * @param type - * @return a new instance of SparseMatrixZeros - */ - public SparseMatrixZeros sparseMatrixZeros(Operand denseShape, - DataType type) { - return SparseMatrixZeros.create(scope, denseShape, type); - } - - /** - * Converts a SparseTensor to a (possibly batched) CSRSparseMatrix. - * - * @param indices SparseTensor indices. - * @param values SparseTensor values. - * @param denseShape SparseTensor dense shape. - * @return a new instance of SparseTensorToCSRSparseMatrix - */ - public SparseTensorToCSRSparseMatrix sparseTensorToCSRSparseMatrix( - Operand indices, Operand values, Operand denseShape) { - return SparseTensorToCSRSparseMatrix.create(scope, indices, values, denseShape); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java index 1f08502ca44..d89679c5084 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.math.Abs; @@ -289,7 +288,7 @@ public Angle angle(Operand input) { * @param Tout * @return a new instance of Angle */ - public Angle angle(Operand input, DataType Tout) { + public Angle angle(Operand input, Class Tout) { return Angle.create(scope, input, Tout); } @@ -357,7 +356,7 @@ public ArgMax argMax(Operand inp * @return a new instance of ArgMax */ public ArgMax argMax(Operand input, - Operand dimension, DataType outputType) { + Operand dimension, Class outputType) { return ArgMax.create(scope, input, dimension, outputType); } @@ -412,7 +411,7 @@ public ArgMin argMin(Operand inp * @return a new instance of ArgMin */ public ArgMin argMin(Operand input, - Operand dimension, DataType outputType) { + Operand dimension, Class outputType) { return ArgMin.create(scope, input, dimension, outputType); } @@ -651,7 +650,7 @@ public ComplexAbs complexAbs(Operand x) { * @return a new instance of ComplexAbs */ public ComplexAbs complexAbs(Operand x, - DataType Tout) { + Class Tout) { return ComplexAbs.create(scope, x, Tout); } @@ -1178,7 +1177,7 @@ public Imag imag(Operand input) { * @param Tout * @return a new instance of Imag */ - public Imag imag(Operand input, DataType Tout) { + public Imag imag(Operand input, Class Tout) { return Imag.create(scope, input, Tout); } @@ -1635,7 +1634,7 @@ public Pow pow(Operand x, Operand y) { */ public QuantizedAdd quantizedAdd( Operand x, Operand y, Operand minX, Operand maxX, - Operand minY, Operand maxY, DataType Toutput) { + Operand minY, Operand maxY, Class Toutput) { return QuantizedAdd.create(scope, x, y, minX, maxX, minY, maxY, Toutput); } @@ -1654,7 +1653,7 @@ public QuantizedAdd quant */ public QuantizedMul quantizedMul( Operand x, Operand y, Operand minX, Operand maxX, - Operand minY, Operand maxY, DataType Toutput) { + Operand minY, Operand maxY, Class Toutput) { return QuantizedMul.create(scope, x, y, minX, maxX, minY, maxY, Toutput); } @@ -1699,7 +1698,7 @@ public Real real(Operand input) { * @param Tout * @return a new instance of Real */ - public Real real(Operand input, DataType Tout) { + public Real real(Operand input, Class Tout) { return Real.create(scope, input, Tout); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 81a24514a08..0f523cf9dfb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.nn.AvgPool; import org.tensorflow.op.nn.AvgPool3d; @@ -642,8 +641,8 @@ public CudnnRNNParamsToCanonical cudnnRNNParamsToCanonica * @return a new instance of CudnnRnnParamsSize */ public CudnnRnnParamsSize cudnnRnnParamsSize( - Operand numLayers, Operand numUnits, Operand inputSize, DataType T, - DataType S, CudnnRnnParamsSize.Options... options) { + Operand numLayers, Operand numUnits, Operand inputSize, Class T, + Class S, CudnnRnnParamsSize.Options... options) { return CudnnRnnParamsSize.create(scope, numLayers, numUnits, inputSize, T, S, options); } @@ -1501,7 +1500,7 @@ public MaxPoolWithArgmax maxPoolWithArgmax(Operan * @return a new instance of MaxPoolWithArgmax */ public MaxPoolWithArgmax maxPoolWithArgmax( - Operand input, List ksize, List strides, DataType Targmax, String padding, + Operand input, List ksize, List strides, Class Targmax, String padding, MaxPoolWithArgmax.Options... options) { return MaxPoolWithArgmax.create(scope, input, ksize, strides, Targmax, padding, options); } @@ -1588,7 +1587,7 @@ public QuantizedBatchNormWithGlobalNormalizat Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, - Operand gamma, Operand gammaMin, Operand gammaMax, DataType outType, + Operand gamma, Operand gammaMin, Operand gammaMax, Class outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { return QuantizedBatchNormWithGlobalNormalization.create(scope, t, tMin, tMax, m, mMin, mMax, v, vMin, vMax, beta, betaMin, betaMax, gamma, gammaMin, gammaMax, outType, varianceEpsilon, scaleAfterNormalization); } @@ -1610,7 +1609,7 @@ public QuantizedBatchNormWithGlobalNormalizat */ public QuantizedBiasAdd quantizedBiasAdd( Operand input, Operand bias, Operand minInput, Operand maxInput, - Operand minBias, Operand maxBias, DataType outType) { + Operand minBias, Operand maxBias, Class outType) { return QuantizedBiasAdd.create(scope, input, bias, minInput, maxInput, minBias, maxBias, outType); } @@ -1638,7 +1637,7 @@ public QuantizedBiasAdd q */ public QuantizedConv2d quantizedConv2d( Operand input, Operand filter, Operand minInput, Operand maxInput, - Operand minFilter, Operand maxFilter, DataType outType, + Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, QuantizedConv2d.Options... options) { return QuantizedConv2d.create(scope, input, filter, minInput, maxInput, minFilter, maxFilter, outType, strides, padding, options); } @@ -1689,7 +1688,7 @@ public QuantizedMaxPool quantizedMaxPool(Operand input, * @return a new instance of QuantizedRelu */ public QuantizedRelu quantizedRelu(Operand features, - Operand minFeatures, Operand maxFeatures, DataType outType) { + Operand minFeatures, Operand maxFeatures, Class outType) { return QuantizedRelu.create(scope, features, minFeatures, maxFeatures, outType); } @@ -1704,7 +1703,7 @@ public QuantizedRelu quantizedRelu(Operand * @return a new instance of QuantizedRelu6 */ public QuantizedRelu6 quantizedRelu6(Operand features, - Operand minFeatures, Operand maxFeatures, DataType outType) { + Operand minFeatures, Operand maxFeatures, Class outType) { return QuantizedRelu6.create(scope, features, minFeatures, maxFeatures, outType); } @@ -1721,7 +1720,7 @@ public QuantizedRelu6 quantizedRelu6(Opera */ public QuantizedReluX quantizedReluX(Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, - DataType outType) { + Class outType) { return QuantizedReluX.create(scope, features, maxValue, minFeatures, maxFeatures, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 8c7aa1c0408..f1abd40eb39 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -19,11 +19,9 @@ import java.nio.charset.Charset; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -243,8 +241,6 @@ import org.tensorflow.op.core.TensorListSetItem; import org.tensorflow.op.core.TensorListSplit; import org.tensorflow.op.core.TensorListStack; -import org.tensorflow.op.core.TensorScatterMax; -import org.tensorflow.op.core.TensorScatterMin; import org.tensorflow.op.core.TensorScatterNdAdd; import org.tensorflow.op.core.TensorScatterNdMax; import org.tensorflow.op.core.TensorScatterNdMin; @@ -647,7 +643,7 @@ public AssignVariableOp assignVariableOp(Operand resource, * @param options carries optional attributes values * @return a new instance of Barrier */ - public Barrier barrier(List> componentTypes, Barrier.Options... options) { + public Barrier barrier(List> componentTypes, Barrier.Options... options) { return Barrier.create(scope, componentTypes, options); } @@ -728,7 +724,7 @@ public BarrierReadySize barrierReadySize(Operand handle) { * @return a new instance of BarrierTakeMany */ public BarrierTakeMany barrierTakeMany(Operand handle, Operand numElements, - List> componentTypes, BarrierTakeMany.Options... options) { + List> componentTypes, BarrierTakeMany.Options... options) { return BarrierTakeMany.create(scope, handle, numElements, componentTypes, options); } @@ -988,7 +984,7 @@ public BatchToSpaceNd * @param type * @return a new instance of Bitcast */ - public Bitcast bitcast(Operand input, DataType type) { + public Bitcast bitcast(Operand input, Class type) { return Bitcast.create(scope, input, type); } @@ -1713,7 +1709,7 @@ public Constant constant(Shape shape) { * @param tensor a Tensor holding the constant value * @return a constant of the same data type as `tensor` */ - public Constant constant(Tensor tensor) { + public Constant constant(T tensor) { return Constant.create(scope, tensor); } @@ -1865,15 +1861,14 @@ public Constant constant(Charset charset, Shape shape, DataBuffer Constant constant(DataType type, Shape shape, - ByteDataBuffer data) { + public Constant constant(Class type, Shape shape, ByteDataBuffer data) { return Constant.tensorOf(scope, type, shape, data); } @@ -2134,7 +2129,7 @@ public EditDistance editDistance(Operand hypothesisInd * @param options carries optional attributes values * @return a new instance of Empty */ - public Empty empty(Operand shape, DataType dtype, + public Empty empty(Operand shape, Class dtype, Empty.Options... options) { return Empty.create(scope, shape, dtype, options); } @@ -2155,7 +2150,7 @@ public Empty empty(Operand shape, DataType dtype * @return a new instance of EmptyTensorList */ public EmptyTensorList emptyTensorList( - Operand elementShape, Operand maxNumElements, DataType elementDtype) { + Operand elementShape, Operand maxNumElements, Class elementDtype) { return EmptyTensorList.create(scope, elementShape, maxNumElements, elementDtype); } @@ -2487,7 +2482,7 @@ public GetSessionHandle getSessionHandle(Operand value) { * @return a new instance of GetSessionTensor */ public GetSessionTensor getSessionTensor(Operand handle, - DataType dtype) { + Class dtype) { return GetSessionTensor.create(scope, handle, dtype); } @@ -2567,8 +2562,8 @@ public GuaranteeConst guaranteeConst(Operand input) { * @param options carries optional attributes values * @return a new instance of HashTable */ - public HashTable hashTable(DataType keyDtype, - DataType valueDtype, HashTable.Options... options) { + public HashTable hashTable(Class keyDtype, + Class valueDtype, HashTable.Options... options) { return HashTable.create(scope, keyDtype, valueDtype, options); } @@ -2631,7 +2626,7 @@ public HistogramFixedWidth histogramFixedWidth(Opera * @return a new instance of HistogramFixedWidth */ public HistogramFixedWidth histogramFixedWidth( - Operand values, Operand valueRange, Operand nbins, DataType dtype) { + Operand values, Operand valueRange, Operand nbins, Class dtype) { return HistogramFixedWidth.create(scope, values, valueRange, nbins, dtype); } @@ -2681,7 +2676,7 @@ public IdentityN identityN(Iterable> input) { * NewReadOnlyMemoryRegionFromFile in tensorflow::Env. * @return a new instance of ImmutableConst */ - public ImmutableConst immutableConst(DataType dtype, Shape shape, + public ImmutableConst immutableConst(Class dtype, Shape shape, String memoryRegionName) { return ImmutableConst.create(scope, dtype, shape, memoryRegionName); } @@ -2876,7 +2871,7 @@ public IsVariableInitialized isVariableInitialized(Operand * @return a new instance of LookupTableExport */ public LookupTableExport lookupTableExport( - Operand tableHandle, DataType Tkeys, DataType Tvalues) { + Operand tableHandle, Class Tkeys, Class Tvalues) { return LookupTableExport.create(scope, tableHandle, Tkeys, Tvalues); } @@ -2962,7 +2957,7 @@ public LoopCond loopCond(Operand input) { * @param options carries optional attributes values * @return a new instance of MapClear */ - public MapClear mapClear(List> dtypes, MapClear.Options... options) { + public MapClear mapClear(List> dtypes, MapClear.Options... options) { return MapClear.create(scope, dtypes, options); } @@ -2973,7 +2968,7 @@ public MapClear mapClear(List> dtypes, MapClear.Options... options) * @param options carries optional attributes values * @return a new instance of MapIncompleteSize */ - public MapIncompleteSize mapIncompleteSize(List> dtypes, + public MapIncompleteSize mapIncompleteSize(List> dtypes, MapIncompleteSize.Options... options) { return MapIncompleteSize.create(scope, dtypes, options); } @@ -2990,8 +2985,8 @@ public MapIncompleteSize mapIncompleteSize(List> dtypes, * @param options carries optional attributes values * @return a new instance of MapPeek */ - public MapPeek mapPeek(Operand key, Operand indices, List> dtypes, - MapPeek.Options... options) { + public MapPeek mapPeek(Operand key, Operand indices, + List> dtypes, MapPeek.Options... options) { return MapPeek.create(scope, key, indices, dtypes, options); } @@ -3002,7 +2997,7 @@ public MapPeek mapPeek(Operand key, Operand indices, List> dtypes, MapSize.Options... options) { + public MapSize mapSize(List> dtypes, MapSize.Options... options) { return MapSize.create(scope, dtypes, options); } @@ -3018,7 +3013,8 @@ public MapSize mapSize(List> dtypes, MapSize.Options... options) { * @return a new instance of MapStage */ public MapStage mapStage(Operand key, Operand indices, - Iterable> values, List> dtypes, MapStage.Options... options) { + Iterable> values, List> dtypes, + MapStage.Options... options) { return MapStage.create(scope, key, indices, values, dtypes, options); } @@ -3035,7 +3031,7 @@ public MapStage mapStage(Operand key, Operand indices, * @return a new instance of MapUnstage */ public MapUnstage mapUnstage(Operand key, Operand indices, - List> dtypes, MapUnstage.Options... options) { + List> dtypes, MapUnstage.Options... options) { return MapUnstage.create(scope, key, indices, dtypes, options); } @@ -3050,8 +3046,8 @@ public MapUnstage mapUnstage(Operand key, Operand indices, * @param options carries optional attributes values * @return a new instance of MapUnstageNoKey */ - public MapUnstageNoKey mapUnstageNoKey(Operand indices, List> dtypes, - MapUnstageNoKey.Options... options) { + public MapUnstageNoKey mapUnstageNoKey(Operand indices, + List> dtypes, MapUnstageNoKey.Options... options) { return MapUnstageNoKey.create(scope, indices, dtypes, options); } @@ -3192,7 +3188,7 @@ public MirrorPad mirrorPad(Operand in * @return a new instance of MlirPassthroughOp */ public MlirPassthroughOp mlirPassthroughOp(Iterable> inputs, String mlirModule, - List> Toutputs) { + List> Toutputs) { return MlirPassthroughOp.create(scope, inputs, mlirModule, Toutputs); } @@ -3214,7 +3210,7 @@ public MlirPassthroughOp mlirPassthroughOp(Iterable> inputs, String m * @return a new instance of MutableDenseHashTable */ public MutableDenseHashTable mutableDenseHashTable( - Operand emptyKey, Operand deletedKey, DataType valueDtype, + Operand emptyKey, Operand deletedKey, Class valueDtype, MutableDenseHashTable.Options... options) { return MutableDenseHashTable.create(scope, emptyKey, deletedKey, valueDtype, options); } @@ -3231,8 +3227,8 @@ public MutableDenseHashTable mutableDenseHash * @param options carries optional attributes values * @return a new instance of MutableHashTable */ - public MutableHashTable mutableHashTable(DataType keyDtype, - DataType valueDtype, MutableHashTable.Options... options) { + public MutableHashTable mutableHashTable(Class keyDtype, + Class valueDtype, MutableHashTable.Options... options) { return MutableHashTable.create(scope, keyDtype, valueDtype, options); } @@ -3249,7 +3245,7 @@ public MutableHashTable mutableHashTable(Data * @return a new instance of MutableHashTableOfTensors */ public MutableHashTableOfTensors mutableHashTableOfTensors( - DataType keyDtype, DataType valueDtype, MutableHashTableOfTensors.Options... options) { + Class keyDtype, Class valueDtype, MutableHashTableOfTensors.Options... options) { return MutableHashTableOfTensors.create(scope, keyDtype, valueDtype, options); } @@ -3443,7 +3439,7 @@ public OnesLike onesLike(Operand x) { * @param options carries optional attributes values * @return a new instance of OrderedMapClear */ - public OrderedMapClear orderedMapClear(List> dtypes, + public OrderedMapClear orderedMapClear(List> dtypes, OrderedMapClear.Options... options) { return OrderedMapClear.create(scope, dtypes, options); } @@ -3455,7 +3451,7 @@ public OrderedMapClear orderedMapClear(List> dtypes, * @param options carries optional attributes values * @return a new instance of OrderedMapIncompleteSize */ - public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtypes, + public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtypes, OrderedMapIncompleteSize.Options... options) { return OrderedMapIncompleteSize.create(scope, dtypes, options); } @@ -3474,7 +3470,7 @@ public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtype * @return a new instance of OrderedMapPeek */ public OrderedMapPeek orderedMapPeek(Operand key, Operand indices, - List> dtypes, OrderedMapPeek.Options... options) { + List> dtypes, OrderedMapPeek.Options... options) { return OrderedMapPeek.create(scope, key, indices, dtypes, options); } @@ -3485,7 +3481,7 @@ public OrderedMapPeek orderedMapPeek(Operand key, Operand indice * @param options carries optional attributes values * @return a new instance of OrderedMapSize */ - public OrderedMapSize orderedMapSize(List> dtypes, + public OrderedMapSize orderedMapSize(List> dtypes, OrderedMapSize.Options... options) { return OrderedMapSize.create(scope, dtypes, options); } @@ -3504,7 +3500,8 @@ public OrderedMapSize orderedMapSize(List> dtypes, * @return a new instance of OrderedMapStage */ public OrderedMapStage orderedMapStage(Operand key, Operand indices, - Iterable> values, List> dtypes, OrderedMapStage.Options... options) { + Iterable> values, List> dtypes, + OrderedMapStage.Options... options) { return OrderedMapStage.create(scope, key, indices, values, dtypes, options); } @@ -3521,7 +3518,7 @@ public OrderedMapStage orderedMapStage(Operand key, Operand indi * @return a new instance of OrderedMapUnstage */ public OrderedMapUnstage orderedMapUnstage(Operand key, Operand indices, - List> dtypes, OrderedMapUnstage.Options... options) { + List> dtypes, OrderedMapUnstage.Options... options) { return OrderedMapUnstage.create(scope, key, indices, dtypes, options); } @@ -3537,7 +3534,7 @@ public OrderedMapUnstage orderedMapUnstage(Operand key, Operand * @return a new instance of OrderedMapUnstageNoKey */ public OrderedMapUnstageNoKey orderedMapUnstageNoKey(Operand indices, - List> dtypes, OrderedMapUnstageNoKey.Options... options) { + List> dtypes, OrderedMapUnstageNoKey.Options... options) { return OrderedMapUnstageNoKey.create(scope, indices, dtypes, options); } @@ -3688,7 +3685,7 @@ public ParallelDynamicStitch parallelDynamicStitch( * @param options carries optional attributes values * @return a new instance of Placeholder */ - public Placeholder placeholder(DataType dtype, + public Placeholder placeholder(Class dtype, Placeholder.Options... options) { return Placeholder.create(scope, dtype, options); } @@ -3817,8 +3814,7 @@ public Rank rank(Operand input) { * @param dtype the dtype of the value. * @return a new instance of ReadVariableOp */ - public ReadVariableOp readVariableOp(Operand resource, - DataType dtype) { + public ReadVariableOp readVariableOp(Operand resource, Class dtype) { return ReadVariableOp.create(scope, resource, dtype); } @@ -3999,7 +3995,7 @@ public RefSwitch refSwitch(Operand data, Operand * @return a new instance of RemoteFusedGraphExecute */ public RemoteFusedGraphExecute remoteFusedGraphExecute(Iterable> inputs, - List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { + List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { return RemoteFusedGraphExecute.create(scope, inputs, Toutputs, serializedRemoteFusedGraphExecuteInfo); } @@ -4086,7 +4082,7 @@ public Reshape reshape(Operand tensor * @return a new instance of ResourceCountUpTo */ public ResourceCountUpTo resourceCountUpTo(Operand resource, Long limit, - DataType T) { + Class T) { return ResourceCountUpTo.create(scope, resource, limit, T); } @@ -4114,7 +4110,7 @@ public ResourceCountUpTo resourceCountUpTo(Operand res * @return a new instance of ResourceGather */ public ResourceGather resourceGather(Operand resource, - Operand indices, DataType dtype, ResourceGather.Options... options) { + Operand indices, Class dtype, ResourceGather.Options... options) { return ResourceGather.create(scope, resource, indices, dtype, options); } @@ -4127,7 +4123,7 @@ public ResourceGather resourceGather(Ope * @return a new instance of ResourceGatherNd */ public ResourceGatherNd resourceGatherNd( - Operand resource, Operand indices, DataType dtype) { + Operand resource, Operand indices, Class dtype) { return ResourceGatherNd.create(scope, resource, indices, dtype); } @@ -5406,7 +5402,7 @@ public SetDiff1d setDiff1d(Operand x, Operand * @return a new instance of SetDiff1d */ public SetDiff1d setDiff1d(Operand x, Operand y, - DataType outIdx) { + Class outIdx) { return SetDiff1d.create(scope, x, y, outIdx); } @@ -5467,7 +5463,7 @@ public org.tensorflow.op.core.Shape shape(Operand i * @return a new instance of Shape */ public org.tensorflow.op.core.Shape shape( - Operand input, DataType outType) { + Operand input, Class outType) { return org.tensorflow.op.core.Shape.create(scope, input, outType); } @@ -5495,7 +5491,7 @@ public ShapeN shapeN(Iterable> input) { * @return a new instance of ShapeN */ public ShapeN shapeN(Iterable> input, - DataType outType) { + Class outType) { return ShapeN.create(scope, input, outType); } @@ -5536,7 +5532,7 @@ public Size size(Operand input) { * @param outType * @return a new instance of Size */ - public Size size(Operand input, DataType outType) { + public Size size(Operand input, Class outType) { return Size.create(scope, input, outType); } @@ -5816,7 +5812,7 @@ public Stage stage(Iterable> values, Stage.Options... options) { * @param options carries optional attributes values * @return a new instance of StageClear */ - public StageClear stageClear(List> dtypes, StageClear.Options... options) { + public StageClear stageClear(List> dtypes, StageClear.Options... options) { return StageClear.create(scope, dtypes, options); } @@ -5832,7 +5828,7 @@ public StageClear stageClear(List> dtypes, StageClear.Options... opt * @param options carries optional attributes values * @return a new instance of StagePeek */ - public StagePeek stagePeek(Operand index, List> dtypes, + public StagePeek stagePeek(Operand index, List> dtypes, StagePeek.Options... options) { return StagePeek.create(scope, index, dtypes, options); } @@ -5844,7 +5840,7 @@ public StagePeek stagePeek(Operand index, List> dtypes, * @param options carries optional attributes values * @return a new instance of StageSize */ - public StageSize stageSize(List> dtypes, StageSize.Options... options) { + public StageSize stageSize(List> dtypes, StageSize.Options... options) { return StageSize.create(scope, dtypes, options); } @@ -6109,7 +6105,7 @@ public SwitchCond switchCond(Operand data, Operand TemporaryVariable temporaryVariable(Shape shape, DataType dtype, + public TemporaryVariable temporaryVariable(Shape shape, Class dtype, TemporaryVariable.Options... options) { return TemporaryVariable.create(scope, shape, dtype, options); } @@ -6124,7 +6120,7 @@ public TemporaryVariable temporaryVariable(Shape shape, Dat * @param options carries optional attributes values * @return a new instance of TensorArray */ - public TensorArray tensorArray(Operand size, DataType dtype, + public TensorArray tensorArray(Operand size, Class dtype, TensorArray.Options... options) { return TensorArray.create(scope, size, dtype, options); } @@ -6164,7 +6160,7 @@ public TensorArrayClose tensorArrayClose(Operand handle) { * @return a new instance of TensorArrayConcat */ public TensorArrayConcat tensorArrayConcat(Operand handle, - Operand flowIn, DataType dtype, TensorArrayConcat.Options... options) { + Operand flowIn, Class dtype, TensorArrayConcat.Options... options) { return TensorArrayConcat.create(scope, handle, flowIn, dtype, options); } @@ -6182,7 +6178,7 @@ public TensorArrayConcat tensorArrayConcat(Operand handl * @return a new instance of TensorArrayGather */ public TensorArrayGather tensorArrayGather(Operand handle, - Operand indices, Operand flowIn, DataType dtype, + Operand indices, Operand flowIn, Class dtype, TensorArrayGather.Options... options) { return TensorArrayGather.create(scope, handle, indices, flowIn, dtype, options); } @@ -6270,7 +6266,7 @@ public TensorArrayGradWithShape tensorArrayGradWithShape(Operand handle, * @return a new instance of TensorArrayPack */ public TensorArrayPack tensorArrayPack(Operand handle, - Operand flowIn, DataType dtype, TensorArrayPack.Options... options) { + Operand flowIn, Class dtype, TensorArrayPack.Options... options) { return TensorArrayPack.create(scope, handle, flowIn, dtype, options); } @@ -6285,7 +6281,7 @@ public TensorArrayPack tensorArrayPack(Operand han * @return a new instance of TensorArrayRead */ public TensorArrayRead tensorArrayRead(Operand handle, - Operand index, Operand flowIn, DataType dtype) { + Operand index, Operand flowIn, Class dtype) { return TensorArrayRead.create(scope, handle, index, flowIn, dtype); } @@ -6402,7 +6398,7 @@ public TensorArrayWrite tensorArrayWrite(Operand handle, */ public TensorListConcat tensorListConcat( Operand inputHandle, Operand elementShape, Operand leadingDims, - DataType elementDtype) { + Class elementDtype) { return TensorListConcat.create(scope, inputHandle, elementShape, leadingDims, elementDtype); } @@ -6414,7 +6410,7 @@ public TensorListConcat tensorListConcat * @return a new instance of TensorListConcatLists */ public TensorListConcatLists tensorListConcatLists(Operand inputA, - Operand inputB, DataType elementDtype) { + Operand inputB, Class elementDtype) { return TensorListConcatLists.create(scope, inputA, inputB, elementDtype); } @@ -6430,7 +6426,7 @@ public TensorListConcatLists tensorListConcatLists(Operand * @return a new instance of TensorListElementShape */ public TensorListElementShape tensorListElementShape( - Operand inputHandle, DataType shapeType) { + Operand inputHandle, Class shapeType) { return TensorListElementShape.create(scope, inputHandle, shapeType); } @@ -6469,7 +6465,7 @@ public TensorListFromTensor tensorListFromT * @return a new instance of TensorListGather */ public TensorListGather tensorListGather(Operand inputHandle, - Operand indices, Operand elementShape, DataType elementDtype) { + Operand indices, Operand elementShape, Class elementDtype) { return TensorListGather.create(scope, inputHandle, indices, elementShape, elementDtype); } @@ -6483,7 +6479,7 @@ public TensorListGather tensorListGather(Operand inputHa * @return a new instance of TensorListGetItem */ public TensorListGetItem tensorListGetItem(Operand inputHandle, - Operand index, Operand elementShape, DataType elementDtype) { + Operand index, Operand elementShape, Class elementDtype) { return TensorListGetItem.create(scope, inputHandle, index, elementShape, elementDtype); } @@ -6517,7 +6513,7 @@ public TensorListLength tensorListLength(Operand inputHandle) { * @return a new instance of TensorListPopBack */ public TensorListPopBack tensorListPopBack(Operand inputHandle, - Operand elementShape, DataType elementDtype) { + Operand elementShape, Class elementDtype) { return TensorListPopBack.create(scope, inputHandle, elementShape, elementDtype); } @@ -6564,7 +6560,7 @@ public TensorListPushBackBatch tensorListPushBackBatch(Operand * @return a new instance of TensorListReserve */ public TensorListReserve tensorListReserve( - Operand elementShape, Operand numElements, DataType elementDtype) { + Operand elementShape, Operand numElements, Class elementDtype) { return TensorListReserve.create(scope, elementShape, numElements, elementDtype); } @@ -6680,36 +6676,10 @@ public TensorListSplit tensorListSplit(Oper * @return a new instance of TensorListStack */ public TensorListStack tensorListStack(Operand inputHandle, - Operand elementShape, DataType elementDtype, TensorListStack.Options... options) { + Operand elementShape, Class elementDtype, TensorListStack.Options... options) { return TensorListStack.create(scope, inputHandle, elementShape, elementDtype, options); } - /** - * - * @param data type for {@code output()} output - * @param tensor Tensor to update. - * @param indices Index tensor. - * @param updates Updates to scatter into output. - * @return a new instance of TensorScatterMax - */ - public TensorScatterMax tensorScatterMax( - Operand tensor, Operand indices, Operand updates) { - return TensorScatterMax.create(scope, tensor, indices, updates); - } - - /** - * - * @param data type for {@code output()} output - * @param tensor Tensor to update. - * @param indices Index tensor. - * @param updates Updates to scatter into output. - * @return a new instance of TensorScatterMin - */ - public TensorScatterMin tensorScatterMin( - Operand tensor, Operand indices, Operand updates) { - return TensorScatterMin.create(scope, tensor, indices, updates); - } - /** * Adds sparse `updates` to an existing tensor according to `indices`. *

@@ -7287,7 +7257,7 @@ public Unique unique(Operand * @return a new instance of Unique */ public Unique unique(Operand x, - Operand axis, DataType outIdx) { + Operand axis, Class outIdx) { return Unique.create(scope, x, axis, outIdx); } @@ -7404,7 +7374,7 @@ public UniqueWithCounts uniqueWi * @return a new instance of UniqueWithCounts */ public UniqueWithCounts uniqueWithCounts( - Operand x, Operand axis, DataType outIdx) { + Operand x, Operand axis, Class outIdx) { return UniqueWithCounts.create(scope, x, axis, outIdx); } @@ -7477,7 +7447,7 @@ public Unstack unstack(Operand value, Long num, * @param options carries optional attributes values * @return a new instance of Unstage */ - public Unstage unstage(List> dtypes, Unstage.Options... options) { + public Unstage unstage(List> dtypes, Unstage.Options... options) { return Unstage.create(scope, dtypes, options); } @@ -7490,7 +7460,7 @@ public Unstage unstage(List> dtypes, Unstage.Options... options) { * @param options carries optional attributes values * @return a new instance of VarHandleOp */ - public VarHandleOp varHandleOp(DataType dtype, Shape shape, + public VarHandleOp varHandleOp(Class dtype, Shape shape, VarHandleOp.Options... options) { return VarHandleOp.create(scope, dtype, shape, options); } @@ -7533,7 +7503,7 @@ public Variable variable(Operand init, Variable.Options. * @param options carries optional attributes values * @return a new instance of Variable */ - public Variable variable(Shape shape, DataType dtype, + public Variable variable(Shape shape, Class dtype, Variable.Options... options) { return Variable.create(scope, shape, dtype, options); } @@ -7573,7 +7543,7 @@ public VariableShape variableShape(Operand input) { * @param outType * @return a new instance of VariableShape */ - public VariableShape variableShape(Operand input, DataType outType) { + public VariableShape variableShape(Operand input, Class outType) { return VariableShape.create(scope, input, outType); } @@ -7691,7 +7661,7 @@ public XlaSpmdShardToFullShape xlaSpmdShardToFullShape(Oper * @return a constant tensor initialized with zeros * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. */ - public Zeros zeros(Operand dims, DataType type) { + public Zeros zeros(Operand dims, Class type) { return Zeros.create(scope, dims, type); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java index aec0d667c65..045ed3d4b03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.quantization.Dequantize; import org.tensorflow.op.quantization.FakeQuantWithMinMaxArgs; @@ -173,7 +172,7 @@ public Dequantize dequantize(Operand input, * @return a new instance of Dequantize */ public Dequantize dequantize(Operand input, - Operand minRange, Operand maxRange, DataType dtype, + Operand minRange, Operand maxRange, Class dtype, Dequantize.Options... options) { return Dequantize.create(scope, input, minRange, maxRange, dtype, options); } @@ -504,7 +503,7 @@ public FakeQuantWithMinMaxVarsPerChannelGradient fakeQuantWithMinMaxVarsPerChann * @return a new instance of Quantize */ public Quantize quantize(Operand input, Operand minRange, - Operand maxRange, DataType T, Quantize.Options... options) { + Operand maxRange, Class T, Quantize.Options... options) { return Quantize.create(scope, input, minRange, maxRange, T, options); } @@ -562,8 +561,7 @@ public QuantizeAndDequantize quantizeAndDequantize(Operan * @return a new instance of QuantizeDownAndShrinkRange */ public QuantizeDownAndShrinkRange quantizeDownAndShrinkRange( - Operand input, Operand inputMin, Operand inputMax, - DataType outType) { + Operand input, Operand inputMin, Operand inputMax, Class outType) { return QuantizeDownAndShrinkRange.create(scope, input, inputMin, inputMax, outType); } @@ -625,7 +623,7 @@ public RequantizationRange requantizationRange(Operand inpu */ public Requantize requantize(Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, - Operand requestedOutputMax, DataType outType) { + Operand requestedOutputMax, Class outType) { return Requantize.create(scope, input, inputMin, inputMax, requestedOutputMin, requestedOutputMax, outType); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java index 071e77c7a70..5fe5b6268fe 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.random.AllCandidateSampler; import org.tensorflow.op.random.LogUniformCandidateSampler; @@ -141,7 +140,7 @@ public Multinomial multinomial(Operand logits, * @return a new instance of Multinomial */ public Multinomial multinomial(Operand logits, - Operand numSamples, DataType outputDtype, Multinomial.Options... options) { + Operand numSamples, Class outputDtype, Multinomial.Options... options) { return Multinomial.create(scope, logits, numSamples, outputDtype, options); } @@ -236,7 +235,7 @@ public RandomPoisson randomPoisso * @return a new instance of RandomPoisson */ public RandomPoisson randomPoisson( - Operand shape, Operand rate, DataType dtype, RandomPoisson.Options... options) { + Operand shape, Operand rate, Class dtype, RandomPoisson.Options... options) { return RandomPoisson.create(scope, shape, rate, dtype, options); } @@ -274,7 +273,7 @@ public RandomShuffle randomShuffle(Operand value, * @return a new instance of RandomStandardNormal */ public RandomStandardNormal randomStandardNormal( - Operand shape, DataType dtype, RandomStandardNormal.Options... options) { + Operand shape, Class dtype, RandomStandardNormal.Options... options) { return RandomStandardNormal.create(scope, shape, dtype, options); } @@ -291,7 +290,7 @@ public RandomStandardNormal randomStan * @return a new instance of RandomUniform */ public RandomUniform randomUniform(Operand shape, - DataType dtype, RandomUniform.Options... options) { + Class dtype, RandomUniform.Options... options) { return RandomUniform.create(scope, shape, dtype, options); } @@ -358,7 +357,7 @@ public StatefulRandomBinomial sta */ public StatefulRandomBinomial statefulRandomBinomial( Operand resource, Operand algorithm, Operand shape, Operand counts, - Operand probs, DataType dtype) { + Operand probs, Class dtype) { return StatefulRandomBinomial.create(scope, resource, algorithm, shape, counts, probs, dtype); } @@ -391,7 +390,7 @@ public StatefulStandardNormal statefulStandardNormal * @return a new instance of StatefulStandardNormal */ public StatefulStandardNormal statefulStandardNormal( - Operand resource, Operand algorithm, Operand shape, DataType dtype) { + Operand resource, Operand algorithm, Operand shape, Class dtype) { return StatefulStandardNormal.create(scope, resource, algorithm, shape, dtype); } @@ -422,7 +421,7 @@ public StatelessMultinomial state * @return a new instance of StatelessMultinomial */ public StatelessMultinomial statelessMultinomial( - Operand logits, Operand numSamples, Operand seed, DataType outputDtype) { + Operand logits, Operand numSamples, Operand seed, Class outputDtype) { return StatelessMultinomial.create(scope, logits, numSamples, seed, outputDtype); } @@ -457,7 +456,7 @@ public StatelessRandomNormal st * @return a new instance of StatelessRandomNormal */ public StatelessRandomNormal statelessRandomNormal( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessRandomNormal.create(scope, shape, seed, dtype); } @@ -494,7 +493,7 @@ public StatelessRandomUniform s * @return a new instance of StatelessRandomUniform */ public StatelessRandomUniform statelessRandomUniform( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessRandomUniform.create(scope, shape, seed, dtype); } @@ -533,7 +532,7 @@ public StatelessTruncatedNormal * @return a new instance of StatelessTruncatedNormal */ public StatelessTruncatedNormal statelessTruncatedNormal( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessTruncatedNormal.create(scope, shape, seed, dtype); } @@ -551,7 +550,7 @@ public StatelessTrunca * @return a new instance of TruncatedNormal */ public TruncatedNormal truncatedNormal(Operand shape, - DataType dtype, TruncatedNormal.Options... options) { + Class dtype, TruncatedNormal.Options... options) { return TruncatedNormal.create(scope, shape, dtype, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java index 81c692571f1..3683805eb2e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.core.Shape; import org.tensorflow.op.core.Shapes; @@ -116,7 +115,7 @@ public Operand flatten(Shape shape) { * @return the reshaped operand */ public Operand flatten(Operand operand, - DataType dType) { + Class dType) { return Shapes.flatten(scope, operand, dType); } @@ -129,7 +128,7 @@ public Operand flatten(Operand operan * @param dType the shape datatype * @return the flattened shape */ - public Operand flatten(Shape shape, DataType dType) { + public Operand flatten(Shape shape, Class dType) { return Shapes.flatten(scope, shape, dType); } @@ -153,7 +152,7 @@ public Operand head(Shape shape) { * @param the shape datatype. * @return a 1-dimensional Operand containing the Shape's first dimension */ - public Operand head(Shape shape, DataType dType) { + public Operand head(Shape shape, Class dType) { return Shapes.head(scope, shape, dType); } @@ -177,7 +176,7 @@ public Operand numDimensions(Shape shape) { * @param dType the shape datatype * @return the number of dimensions */ - public Operand numDimensions(Shape shape, DataType dType) { + public Operand numDimensions(Shape shape, Class dType) { return Shapes.numDimensions(scope, shape, dType); } @@ -262,7 +261,7 @@ public Operand reduceDims(Shape shape, Operand axis) { * @return the reshaped operand */ public Operand reduceDims(Operand operand, - Operand axis, DataType dType) { + Operand axis, Class dType) { return Shapes.reduceDims(scope, operand, axis, dType); } @@ -277,7 +276,7 @@ public Operand reduceDims(Operand ope * @return the reduced shape */ public Operand reduceDims(Shape shape, Operand axis, - DataType dType) { + Class dType) { return Shapes.reduceDims(scope, shape, axis, dType); } @@ -313,7 +312,7 @@ public Operand size(Operand input, Operand * @param dType the shape datatype * @return the size */ - public Operand size(Shape shape, DataType dType) { + public Operand size(Shape shape, Class dType) { return Shapes.size(scope, shape, dType); } @@ -340,7 +339,7 @@ public Operand size(Shape shape, Operand dim) { * @return the size of the specified dimension */ public Operand size(Operand input, Operand dim, - DataType dType) { + Class dType) { return Shapes.size(scope, input, dim, dType); } @@ -354,7 +353,7 @@ public Operand size(Operand input, Op * @param dType the shape datatype * @return the size of the specified dimension */ - public Operand size(Shape shape, Operand dim, DataType dType) { + public Operand size(Shape shape, Operand dim, Class dType) { return Shapes.size(scope, shape, dim, dType); } @@ -378,7 +377,7 @@ public Operand squeeze(Shape shape) { * @param dType the shape datatype. * @return the squeezed shape */ - public Operand squeeze(Shape shape, DataType dType) { + public Operand squeeze(Shape shape, Class dType) { return Shapes.squeeze(scope, shape, dType); } @@ -406,7 +405,7 @@ public Operand tail(Shape shape) { * @return a 1-dimensional Operand that contains the dimension matching the last dimension of the * Shape */ - public Operand tail(Shape shape, DataType dType) { + public Operand tail(Shape shape, Class dType) { return Shapes.tail(scope, shape, dType); } @@ -436,7 +435,7 @@ public Operand take(Shape shape, Operand n) { * @return a 1-dimensional operand with the dimensions matching * the first n dimensions of the * shape */ - public Operand take(Shape shape, Operand n, DataType dType) { + public Operand take(Shape shape, Operand n, Class dType) { return Shapes.take(scope, shape, n, dType); } @@ -466,7 +465,7 @@ public Operand takeLast(Shape shape, Operand * @return a 1-dimensional operand containing the dimensions matching the last n dimensions of the * shape */ - public Operand takeLast(Shape shape, Operand n, DataType dType) { + public Operand takeLast(Shape shape, Operand n, Class dType) { return Shapes.takeLast(scope, shape, n, dType); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java index f4ec7bdb48d..1466dcd8de4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.signal.BatchFft; import org.tensorflow.op.signal.BatchFft2d; @@ -242,7 +241,7 @@ public Irfft irfft(Operand input, Operand * @return a new instance of Irfft */ public Irfft irfft(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft.create(scope, input, fftLength, Treal); } @@ -298,7 +297,7 @@ public Irfft2d irfft2d(Operand input, Operand Irfft2d irfft2d(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft2d.create(scope, input, fftLength, Treal); } @@ -354,7 +353,7 @@ public Irfft3d irfft3d(Operand input, Operand Irfft3d irfft3d(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft3d.create(scope, input, fftLength, Treal); } @@ -379,7 +378,7 @@ public Irfft3d irfft3d(Operand input, * @return a new instance of Rfft */ public Rfft rfft(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft.create(scope, input, fftLength, Tcomplex); } @@ -405,7 +404,7 @@ public Rfft rfft(Operand input, * @return a new instance of Rfft2d */ public Rfft2d rfft2d(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft2d.create(scope, input, fftLength, Tcomplex); } @@ -431,7 +430,7 @@ public Rfft2d rfft2d(Operand input, * @return a new instance of Rfft3d */ public Rfft3d rfft3d(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft3d.create(scope, input, fftLength, Tcomplex); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java index 42cdf9569d9..d7690dacdb4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.sparse.AddManySparseToTensorsMap; @@ -269,7 +268,7 @@ public DenseToSparseSetOperation denseToSparseSetOperation( * @return a new instance of DeserializeSparse */ public DeserializeSparse deserializeSparse( - Operand serializedSparse, DataType dtype) { + Operand serializedSparse, Class dtype) { return DeserializeSparse.create(scope, serializedSparse, dtype); } @@ -315,7 +314,7 @@ public SparseAccumulatorApplyGradient sparseAccumulatorApplyGr * @return a new instance of SparseAccumulatorTakeGradient */ public SparseAccumulatorTakeGradient sparseAccumulatorTakeGradient( - Operand handle, Operand numRequired, DataType dtype) { + Operand handle, Operand numRequired, Class dtype) { return SparseAccumulatorTakeGradient.create(scope, handle, numRequired, dtype); } @@ -476,8 +475,8 @@ public SparseConcat sparseConcat(Iterable> * @param options carries optional attributes values * @return a new instance of SparseConditionalAccumulator */ - public SparseConditionalAccumulator sparseConditionalAccumulator( - DataType dtype, Shape shape, SparseConditionalAccumulator.Options... options) { + public SparseConditionalAccumulator sparseConditionalAccumulator(Class dtype, + Shape shape, SparseConditionalAccumulator.Options... options) { return SparseConditionalAccumulator.create(scope, dtype, shape, options); } @@ -1493,7 +1492,7 @@ public SparseToSparseSetOperation sparseToSparseSetOperatio * @return a new instance of TakeManySparseFromTensorsMap */ public TakeManySparseFromTensorsMap takeManySparseFromTensorsMap( - Operand sparseHandles, DataType dtype, + Operand sparseHandles, Class dtype, TakeManySparseFromTensorsMap.Options... options) { return TakeManySparseFromTensorsMap.create(scope, sparseHandles, dtype, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java index f6491843332..a09884817f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.strings.Join; import org.tensorflow.op.strings.Lower; @@ -483,8 +482,7 @@ public ToNumber toNumber(Operand stringTensor) { * @param outType The numeric type to interpret each string in `string_tensor` as. * @return a new instance of ToNumber */ - public ToNumber toNumber(Operand stringTensor, - DataType outType) { + public ToNumber toNumber(Operand stringTensor, Class outType) { return ToNumber.create(scope, stringTensor, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java index 2c5d8752136..1e253ac9956 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.train.AccumulatorApplyGradient; @@ -159,7 +158,7 @@ public AccumulatorSetGlobalStep accumulatorSetGlobalStep(Operand handle * @return a new instance of AccumulatorTakeGradient */ public AccumulatorTakeGradient accumulatorTakeGradient( - Operand handle, Operand numRequired, DataType dtype) { + Operand handle, Operand numRequired, Class dtype) { return AccumulatorTakeGradient.create(scope, handle, numRequired, dtype); } @@ -542,7 +541,7 @@ public BatchMatMul batchMatMul(Operand x, Operand y, * @param options carries optional attributes values * @return a new instance of ConditionalAccumulator */ - public ConditionalAccumulator conditionalAccumulator(DataType dtype, + public ConditionalAccumulator conditionalAccumulator(Class dtype, Shape shape, ConditionalAccumulator.Options... options) { return ConditionalAccumulator.create(scope, dtype, shape, options); } @@ -1295,7 +1294,7 @@ public ResourceSparseApplyRmsProp resourceS * @return a new instance of Restore */ public Restore restore(Operand prefix, Operand tensorNames, - Operand shapeAndSlices, List> dtypes) { + Operand shapeAndSlices, List> dtypes) { return Restore.create(scope, prefix, tensorNames, shapeAndSlices, dtypes); } @@ -1321,7 +1320,7 @@ public Restore restore(Operand prefix, Operand tensorNames, * @return a new instance of RestoreSlice */ public RestoreSlice restoreSlice(Operand filePattern, - Operand tensorName, Operand shapeAndSlice, DataType dt, + Operand tensorName, Operand shapeAndSlice, Class dt, RestoreSlice.Options... options) { return RestoreSlice.create(scope, filePattern, tensorName, shapeAndSlice, dt, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index 535972d4883..6c16e3cb7ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.xla.BroadcastHelper; @@ -278,7 +277,7 @@ public Pad pad(Operand input, Operand * @param shape The shape of the tensor. * @return a new instance of Recv */ - public Recv recv(DataType dtype, String tensorName, Shape shape) { + public Recv recv(Class dtype, String tensorName, Shape shape) { return Recv.create(scope, dtype, tensorName, shape); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java index e199ff2201f..2e3d0cb2949 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise AND of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java index 264c2bc340b..abeb509f5c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise OR of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java index 1d8f668c175..4339d2d1df9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise XOR of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java index 9f8bdfd56d8..589f0047356 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java index f7a47534d81..69c818873e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise left-shift of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java index 99c5fe5766e..8106172ee4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise right-shift of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java index d58bc1357f4..0ef37a5a9d2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Mutually reduces multiple tensors of identical type and shape. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java index ab938852275..8603acc0e05 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java @@ -17,7 +17,6 @@ package org.tensorflow.op.collective; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -77,7 +76,7 @@ private Options() { * @return a new instance of BroadcastRecv */ @Endpoint(describeByClass = true) - public static BroadcastRecv create(Scope scope, DataType T, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) { + public static BroadcastRecv create(Scope scope, Class T, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastRecv", scope.makeOpName("BroadcastRecv")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("T", T); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java index 909427d1a57..8de82899104 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical and" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java index 0316e5e1a94..f5e5f3753e2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical or" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java index b9c5c84083f..dc3f8b781bc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,6 +28,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Defines a barrier that persists across different graph executions. @@ -105,10 +105,10 @@ private Options() { * @return a new instance of Barrier */ @Endpoint(describeByClass = true) - public static Barrier create(Scope scope, List> componentTypes, Options... options) { + public static Barrier create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Barrier", scope.makeOpName("Barrier")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java index 6c391fab5fa..9032635571f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -31,6 +30,7 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Takes the given number of completed elements from a barrier. @@ -98,12 +98,12 @@ private Options() { * @return a new instance of BarrierTakeMany */ @Endpoint(describeByClass = true) - public static BarrierTakeMany create(Scope scope, Operand handle, Operand numElements, List> componentTypes, Options... options) { + public static BarrierTakeMany create(Scope scope, Operand handle, Operand numElements, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("BarrierTakeMany", scope.makeOpName("BarrierTakeMany")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numElements.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java index b01c8598ae6..65d39833a78 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -96,7 +95,7 @@ public final class Bitcast extends RawOp implements Operand * @return a new instance of Bitcast */ @Endpoint(describeByClass = true) - public static Bitcast create(Scope scope, Operand input, DataType type) { + public static Bitcast create(Scope scope, Operand input, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("Bitcast", scope.makeOpName("Bitcast")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java index 3027e8234a4..df29691cac8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return the shape of s0 op s1 with broadcast. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java index 2d95c71086e..dffa39e4bf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return the reduction indices for computing gradients of s0 op s1 with broadcast. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java index 84e9b454d0d..05b43dc7679 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Bucketizes 'input' based on 'boundaries'. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java index bd9de3ebb33..d1f4e84f146 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Mutually accumulates multiple tensors of identical type and shape. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java index 2884783695d..811a955d2fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Increments 'ref' until it reaches 'limit'. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java index 28af55f8465..40afa08fc97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -30,6 +29,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * The op extracts fields from a serialized protocol buffers message into tensors. @@ -135,7 +135,7 @@ private Options() { * @return a new instance of DecodeProto */ @Endpoint(describeByClass = true) - public static DecodeProto create(Scope scope, Operand bytes, String messageType, List fieldNames, List> outputTypes, Options... options) { + public static DecodeProto create(Scope scope, Operand bytes, String messageType, List fieldNames, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodeProtoV2", scope.makeOpName("DecodeProto")); opBuilder.addInput(bytes.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -145,7 +145,7 @@ public static DecodeProto create(Scope scope, Operand bytes, String mes fieldNamesArray[i] = fieldNames.get(i); } opBuilder.setAttr("field_names", fieldNamesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DummySeedGenerator.java deleted file mode 100644 index b9cf2c36d09..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DummySeedGenerator.java +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TType; - -/** - */ -public final class DummySeedGenerator extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new DummySeedGenerator operation. - * - * @param scope current scope - * @return a new instance of DummySeedGenerator - */ - @Endpoint(describeByClass = true) - public static DummySeedGenerator create(Scope scope) { - OperationBuilder opBuilder = scope.env().opBuilder("DummySeedGenerator", scope.makeOpName("DummySeedGenerator")); - opBuilder = scope.applyControlDependencies(opBuilder); - return new DummySeedGenerator(opBuilder.build()); - } - - /** - */ - public Output handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output asOutput() { - return (Output) handle; - } - - private Output handle; - - private DummySeedGenerator(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java index c8305349d37..9330b2dfee3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -68,7 +67,7 @@ private Options() { * @return a new instance of Empty */ @Endpoint(describeByClass = true) - public static Empty create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static Empty create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Empty", scope.makeOpName("Empty")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java index 619fb90657f..ecd16df5097 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,7 +52,7 @@ public final class EmptyTensorList extends RawOp implements Operand { * @return a new instance of EmptyTensorList */ @Endpoint(describeByClass = true) - public static EmptyTensorList create(Scope scope, Operand elementShape, Operand maxNumElements, DataType elementDtype) { + public static EmptyTensorList create(Scope scope, Operand elementShape, Operand maxNumElements, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("EmptyTensorList", scope.makeOpName("EmptyTensorList")); opBuilder.addInput(elementShape.asOutput()); opBuilder.addInput(maxNumElements.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java index 0b9bcd78e20..487289a8a48 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extract `patches` from `input` and put them in the "depth" output dimension. 3D extension of `extract_image_patches`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java index 93ba5af508c..d5767515e56 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,7 +45,7 @@ public final class GetSessionTensor extends RawOp implements Op * @return a new instance of GetSessionTensor */ @Endpoint(describeByClass = true) - public static GetSessionTensor create(Scope scope, Operand handle, DataType dtype) { + public static GetSessionTensor create(Scope scope, Operand handle, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("GetSessionTensor", scope.makeOpName("GetSessionTensor")); opBuilder.addInput(handle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java index 87d9cab4c3f..aea656de1ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -88,7 +87,7 @@ private Options() { * @return a new instance of HashTable */ @Endpoint(describeByClass = true) - public static HashTable create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static HashTable create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("HashTableV2", scope.makeOpName("HashTable")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("key_dtype", keyDtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java index da1c1a7b713..1565dd2a8d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return histogram of values. @@ -67,7 +65,7 @@ public final class HistogramFixedWidth extends RawOp implemen * @return a new instance of HistogramFixedWidth */ @Endpoint(describeByClass = true) - public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins, DataType dtype) { + public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("HistogramFixedWidth", scope.makeOpName("HistogramFixedWidth")); opBuilder.addInput(values.asOutput()); opBuilder.addInput(valueRange.asOutput()); @@ -90,7 +88,7 @@ public static HistogramFixedWidth crea */ @Endpoint(describeByClass = true) public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins) { - return create(scope, values, valueRange, nbins, TInt32.DTYPE); + return create(scope, values, valueRange, nbins, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java index ecbc3154498..3d2ac9625da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -50,7 +49,7 @@ public final class ImmutableConst extends RawOp implements Oper * @return a new instance of ImmutableConst */ @Endpoint(describeByClass = true) - public static ImmutableConst create(Scope scope, DataType dtype, Shape shape, String memoryRegionName) { + public static ImmutableConst create(Scope scope, Class dtype, Shape shape, String memoryRegionName) { OperationBuilder opBuilder = scope.env().opBuilder("ImmutableConst", scope.makeOpName("ImmutableConst")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java index 04f3959820a..fbfa0e69881 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generates values in an interval. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java index 6bebbb35895..527e70e7fc0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,7 +46,7 @@ public final class LookupTableExport extends R * @return a new instance of LookupTableExport */ @Endpoint(describeByClass = true) - public static LookupTableExport create(Scope scope, Operand tableHandle, DataType Tkeys, DataType Tvalues) { + public static LookupTableExport create(Scope scope, Operand tableHandle, Class Tkeys, Class Tvalues) { OperationBuilder opBuilder = scope.env().opBuilder("LookupTableExportV2", scope.makeOpName("LookupTableExport")); opBuilder.addInput(tableHandle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java index 6f0f5158cca..7ec77e7df0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -66,7 +65,7 @@ public final class LowerBound extends RawOp implements Operan * @return a new instance of LowerBound */ @Endpoint(describeByClass = true) - public static LowerBound create(Scope scope, Operand sortedInputs, Operand values, DataType outType) { + public static LowerBound create(Scope scope, Operand sortedInputs, Operand values, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("LowerBound", scope.makeOpName("LowerBound")); opBuilder.addInput(sortedInputs.asOutput()); opBuilder.addInput(values.asOutput()); @@ -86,7 +85,7 @@ public static LowerBound create(Scope sc */ @Endpoint(describeByClass = true) public static LowerBound create(Scope scope, Operand sortedInputs, Operand values) { - return create(scope, sortedInputs, values, TInt32.DTYPE); + return create(scope, sortedInputs, values, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java index bad1e90554f..3c3dca46cc6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java @@ -18,13 +18,13 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,10 +87,10 @@ private Options() { * @return a new instance of MapClear */ @Endpoint(describeByClass = true) - public static MapClear create(Scope scope, List> dtypes, Options... options) { + public static MapClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapClear", scope.makeOpName("MapClear")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java index 19e9e87a08a..76b1513c6bd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,6 +27,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of incomplete elements in the underlying container. @@ -90,10 +90,10 @@ private Options() { * @return a new instance of MapIncompleteSize */ @Endpoint(describeByClass = true) - public static MapIncompleteSize create(Scope scope, List> dtypes, Options... options) { + public static MapIncompleteSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapIncompleteSize", scope.makeOpName("MapIncompleteSize")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java index 1925ca680ea..46dccfc6a3f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -99,12 +98,12 @@ private Options() { * @return a new instance of MapPeek */ @Endpoint(describeByClass = true) - public static MapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static MapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapPeek", scope.makeOpName("MapPeek")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java index 7f4eea906f5..600b9473da9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,6 +27,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,10 +90,10 @@ private Options() { * @return a new instance of MapSize */ @Endpoint(describeByClass = true) - public static MapSize create(Scope scope, List> dtypes, Options... options) { + public static MapSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapSize", scope.makeOpName("MapSize")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java index 9291b32d53b..eb75d99ba5c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,6 +28,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Stage (key, values) in the underlying container which behaves like a hashtable. @@ -97,13 +97,13 @@ private Options() { * @return a new instance of MapStage */ @Endpoint(describeByClass = true) - public static MapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { + public static MapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapStage", scope.makeOpName("MapStage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInputList(Operands.asOutputs(values)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java index 849f6f3ef6a..ba02210d5c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -99,12 +98,12 @@ private Options() { * @return a new instance of MapUnstage */ @Endpoint(describeByClass = true) - public static MapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static MapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapUnstage", scope.makeOpName("MapUnstage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java index 10a119ec6c4..f6cfe6c1d73 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -30,6 +29,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Op removes and returns a random (key, value) @@ -96,11 +96,11 @@ private Options() { * @return a new instance of MapUnstageNoKey */ @Endpoint(describeByClass = true) - public static MapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { + public static MapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapUnstageNoKey", scope.makeOpName("MapUnstageNoKey")); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java index cc278fdae8f..027c75bf163 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -78,12 +77,12 @@ public final class MlirPassthroughOp extends RawOp implements Iterable> inputs, String mlirModule, List> Toutputs) { + public static MlirPassthroughOp create(Scope scope, Iterable> inputs, String mlirModule, List> Toutputs) { OperationBuilder opBuilder = scope.env().opBuilder("MlirPassthroughOp", scope.makeOpName("MlirPassthroughOp")); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("mlir_module", mlirModule); - DataType[] ToutputsArray = new DataType[Toutputs.size()]; + Class[] ToutputsArray = new Class[Toutputs.size()]; for (int i = 0; i < ToutputsArray.length; ++i) { ToutputsArray[i] = Toutputs.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java index 4ff3c1f3ea6..0e6b35f502a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -122,7 +121,7 @@ private Options() { * @return a new instance of MutableDenseHashTable */ @Endpoint(describeByClass = true) - public static MutableDenseHashTable create(Scope scope, Operand emptyKey, Operand deletedKey, DataType valueDtype, Options... options) { + public static MutableDenseHashTable create(Scope scope, Operand emptyKey, Operand deletedKey, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableDenseHashTableV2", scope.makeOpName("MutableDenseHashTable")); opBuilder.addInput(emptyKey.asOutput()); opBuilder.addInput(deletedKey.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java index a11789f9f34..0e2dd872331 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -88,7 +87,7 @@ private Options() { * @return a new instance of MutableHashTable */ @Endpoint(describeByClass = true) - public static MutableHashTable create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static MutableHashTable create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableHashTableV2", scope.makeOpName("MutableHashTable")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("key_dtype", keyDtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java index 975f040ae3b..49d2e3c0373 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -97,7 +96,7 @@ private Options() { * @return a new instance of MutableHashTableOfTensors */ @Endpoint(describeByClass = true) - public static MutableHashTableOfTensors create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static MutableHashTableOfTensors create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableHashTableOfTensorsV2", scope.makeOpName("MutableHashTableOfTensors")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("key_dtype", keyDtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java index a44611b5a41..171cf01d750 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a tensor containing the reduction across all input tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java index da0832e437a..17db313a42e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Sends `input` to all devices that are connected to the output. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java index ced1473e60a..6156dbc5551 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces `input` from `num_devices` using `reduction` to a single device. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java index 05a1b7ab984..df4b4f26bfc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java @@ -18,13 +18,13 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,10 +87,10 @@ private Options() { * @return a new instance of OrderedMapClear */ @Endpoint(describeByClass = true) - public static OrderedMapClear create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapClear", scope.makeOpName("OrderedMapClear")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java index 865810568db..1b686b2f14b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,6 +27,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of incomplete elements in the underlying container. @@ -90,10 +90,10 @@ private Options() { * @return a new instance of OrderedMapIncompleteSize */ @Endpoint(describeByClass = true) - public static OrderedMapIncompleteSize create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapIncompleteSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapIncompleteSize", scope.makeOpName("OrderedMapIncompleteSize")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java index 21c6adcc039..6b052e5f00c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -100,12 +99,12 @@ private Options() { * @return a new instance of OrderedMapPeek */ @Endpoint(describeByClass = true) - public static OrderedMapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static OrderedMapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapPeek", scope.makeOpName("OrderedMapPeek")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java index afdee7de1bd..69c7851b13c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,6 +27,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,10 +90,10 @@ private Options() { * @return a new instance of OrderedMapSize */ @Endpoint(describeByClass = true) - public static OrderedMapSize create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapSize", scope.makeOpName("OrderedMapSize")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java index 7e02973e3c6..12f3155d661 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,6 +28,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Stage (key, values) in the underlying container which behaves like a ordered @@ -99,13 +99,13 @@ private Options() { * @return a new instance of OrderedMapStage */ @Endpoint(describeByClass = true) - public static OrderedMapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { + public static OrderedMapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapStage", scope.makeOpName("OrderedMapStage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInputList(Operands.asOutputs(values)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java index e2460e42dd8..8b08d7449c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -99,12 +98,12 @@ private Options() { * @return a new instance of OrderedMapUnstage */ @Endpoint(describeByClass = true) - public static OrderedMapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static OrderedMapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapUnstage", scope.makeOpName("OrderedMapUnstage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java index f20a23b9806..edb9e82401a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -30,6 +29,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Op removes and returns the (key, value) element with the smallest @@ -96,11 +96,11 @@ private Options() { * @return a new instance of OrderedMapUnstageNoKey */ @Endpoint(describeByClass = true) - public static OrderedMapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { + public static OrderedMapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapUnstageNoKey", scope.makeOpName("OrderedMapUnstageNoKey")); opBuilder.addInput(indices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java index caef9fc0783..75d2326d007 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -70,7 +69,7 @@ private Options() { * @return a new instance of Placeholder */ @Endpoint(describeByClass = true) - public static Placeholder create(Scope scope, DataType dtype, Options... options) { + public static Placeholder create(Scope scope, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Placeholder", scope.makeOpName("Placeholder")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java index 967a490382e..4f71eafe3c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Creates a sequence of numbers. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java index aa2d1d2b9b2..6f33217ec56 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -52,7 +51,7 @@ public final class ReadVariableOp extends RawOp implements Oper * @return a new instance of ReadVariableOp */ @Endpoint(describeByClass = true) - public static ReadVariableOp create(Scope scope, Operand resource, DataType dtype) { + public static ReadVariableOp create(Scope scope, Operand resource, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ReadVariableOp", scope.makeOpName("ReadVariableOp")); opBuilder.addInput(resource.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java index db6511c1583..81ebf79a0e8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -70,7 +69,7 @@ private Options() { * @return a new instance of Recv */ @Endpoint(describeByClass = true) - public static Recv create(Scope scope, DataType tensorType, String tensorName, String sendDevice, Long sendDeviceIncarnation, String recvDevice, Options... options) { + public static Recv create(Scope scope, Class tensorType, String tensorName, String sendDevice, Long sendDeviceIncarnation, String recvDevice, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Recv", scope.makeOpName("Recv")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("tensor_type", tensorType); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java index 9a5ad026ac8..1e940c95fb5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical and" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java index de479629f97..412e00fd07e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical or" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java index bd76549c976..f1a70aa35e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -57,11 +56,11 @@ public final class RemoteFusedGraphExecute extends RawOp implements Iterable> inputs, List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { + public static RemoteFusedGraphExecute create(Scope scope, Iterable> inputs, List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { OperationBuilder opBuilder = scope.env().opBuilder("RemoteFusedGraphExecute", scope.makeOpName("RemoteFusedGraphExecute")); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] ToutputsArray = new DataType[Toutputs.size()]; + Class[] ToutputsArray = new Class[Toutputs.size()]; for (int i = 0; i < ToutputsArray.length; ++i) { ToutputsArray[i] = Toutputs.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java index 19f630dd014..072c70a1862 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Increments variable pointed to by 'resource' until it reaches 'limit'. @@ -48,7 +46,7 @@ public final class ResourceCountUpTo extends RawOp implements * @return a new instance of ResourceCountUpTo */ @Endpoint(describeByClass = true) - public static ResourceCountUpTo create(Scope scope, Operand resource, Long limit, DataType T) { + public static ResourceCountUpTo create(Scope scope, Operand resource, Long limit, Class T) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceCountUpTo", scope.makeOpName("ResourceCountUpTo")); opBuilder.addInput(resource.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java index 38aa1fbd407..d470cf09127 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -90,7 +89,7 @@ private Options() { * @return a new instance of ResourceGather */ @Endpoint(describeByClass = true) - public static ResourceGather create(Scope scope, Operand resource, Operand indices, DataType dtype, Options... options) { + public static ResourceGather create(Scope scope, Operand resource, Operand indices, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceGather", scope.makeOpName("ResourceGather")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(indices.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java index 85e422179f7..3f8237d98e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,7 +44,7 @@ public final class ResourceGatherNd extends RawOp implements Op * @return a new instance of ResourceGatherNd */ @Endpoint(describeByClass = true) - public static ResourceGatherNd create(Scope scope, Operand resource, Operand indices, DataType dtype) { + public static ResourceGatherNd create(Scope scope, Operand resource, Operand indices, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceGatherNd", scope.makeOpName("ResourceGatherNd")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(indices.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java index 20b84f016ea..360cb496ee8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces sparse updates into a variable reference using the `max` operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java index 5010da14eaa..14f83e0756a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces sparse updates into a variable reference using the `min` operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java index c7e458931e5..d1c53319c4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -69,7 +68,7 @@ public final class SetDiff1d extends RawOp { * @return a new instance of SetDiff1d */ @Endpoint(describeByClass = true) - public static SetDiff1d create(Scope scope, Operand x, Operand y, DataType outIdx) { + public static SetDiff1d create(Scope scope, Operand x, Operand y, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("ListDiff", scope.makeOpName("SetDiff1d")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); @@ -88,7 +87,7 @@ public static SetDiff1d create(Scope */ @Endpoint(describeByClass = true) public static SetDiff1d create(Scope scope, Operand x, Operand y) { - return create(scope, x, y, TInt32.DTYPE); + return create(scope, x, y, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java index 5d613401fc5..3ee92514a35 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -56,7 +55,7 @@ public final class Shape extends RawOp implements Operand * @return a new instance of Shape */ @Endpoint(describeByClass = true) - public static Shape create(Scope scope, Operand input, DataType outType) { + public static Shape create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Shape", scope.makeOpName("Shape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -73,7 +72,7 @@ public static Shape create(Scope scope, */ @Endpoint(describeByClass = true) public static Shape create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java index 42340a8a83d..1b1af62e7ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,7 +52,7 @@ public final class ShapeN extends RawOp implements Iterable ShapeN create(Scope scope, Iterable> input, DataType outType) { + public static ShapeN create(Scope scope, Iterable> input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("ShapeN", scope.makeOpName("ShapeN")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.applyControlDependencies(opBuilder); @@ -70,7 +69,7 @@ public static ShapeN create(Scope scope, */ @Endpoint(describeByClass = true) public static ShapeN create(Scope scope, Iterable> input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java index 4aaccae0f30..ddd912f68f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -57,7 +56,7 @@ public final class Size extends RawOp implements Operand { * @return a new instance of Size */ @Endpoint(describeByClass = true) - public static Size create(Scope scope, Operand input, DataType outType) { + public static Size create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Size", scope.makeOpName("Size")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -74,7 +73,7 @@ public static Size create(Scope scope, O */ @Endpoint(describeByClass = true) public static Size create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java index 60e51559f74..1bb2fde0fdf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java @@ -18,13 +18,13 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,10 +87,10 @@ private Options() { * @return a new instance of StageClear */ @Endpoint(describeByClass = true) - public static StageClear create(Scope scope, List> dtypes, Options... options) { + public static StageClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StageClear", scope.makeOpName("StageClear")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java index 2126de722df..07ce640c947 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -98,11 +97,11 @@ private Options() { * @return a new instance of StagePeek */ @Endpoint(describeByClass = true) - public static StagePeek create(Scope scope, Operand index, List> dtypes, Options... options) { + public static StagePeek create(Scope scope, Operand index, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StagePeek", scope.makeOpName("StagePeek")); opBuilder.addInput(index.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java index 94ef566e708..f67f23e5bd3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,6 +27,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,10 +90,10 @@ private Options() { * @return a new instance of StageSize */ @Endpoint(describeByClass = true) - public static StageSize create(Scope scope, List> dtypes, Options... options) { + public static StageSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StageSize", scope.makeOpName("StageSize")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java index 43d9247fe21..452d380f859 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -82,7 +81,7 @@ private Options() { * @return a new instance of TemporaryVariable */ @Endpoint(describeByClass = true) - public static TemporaryVariable create(Scope scope, Shape shape, DataType dtype, Options... options) { + public static TemporaryVariable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TemporaryVariable", scope.makeOpName("TemporaryVariable")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("shape", shape); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java index f34dc414484..fb6014a5efd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -116,7 +115,7 @@ private Options() { * @return a new instance of TensorArray */ @Endpoint(describeByClass = true) - public static TensorArray create(Scope scope, Operand size, DataType dtype, Options... options) { + public static TensorArray create(Scope scope, Operand size, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayV3", scope.makeOpName("TensorArray")); opBuilder.addInput(size.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java index e25b7a2f958..47d78d77f1d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -83,7 +82,7 @@ private Options() { * @return a new instance of TensorArrayConcat */ @Endpoint(describeByClass = true) - public static TensorArrayConcat create(Scope scope, Operand handle, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayConcat create(Scope scope, Operand handle, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayConcatV3", scope.makeOpName("TensorArrayConcat")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(flowIn.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java index 13c6dce3122..9ba3bca49f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -74,7 +73,7 @@ private Options() { * @return a new instance of TensorArrayGather */ @Endpoint(describeByClass = true) - public static TensorArrayGather create(Scope scope, Operand handle, Operand indices, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayGather create(Scope scope, Operand handle, Operand indices, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayGatherV3", scope.makeOpName("TensorArrayGather")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(indices.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java index 4b4be4ac089..d622be4d3e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -67,7 +66,7 @@ private Options() { * @return a new instance of TensorArrayPack */ @Endpoint(describeByClass = true) - public static TensorArrayPack create(Scope scope, Operand handle, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayPack create(Scope scope, Operand handle, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayPack", scope.makeOpName("TensorArrayPack")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(flowIn.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java index 55cf48bb020..b71aa6c58a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class TensorArrayRead extends RawOp implements Ope * @return a new instance of TensorArrayRead */ @Endpoint(describeByClass = true) - public static TensorArrayRead create(Scope scope, Operand handle, Operand index, Operand flowIn, DataType dtype) { + public static TensorArrayRead create(Scope scope, Operand handle, Operand index, Operand flowIn, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayReadV3", scope.makeOpName("TensorArrayRead")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(index.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java index 2f1771797e0..6f83df9232c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -62,7 +61,7 @@ public final class TensorListConcat extends RawOp { * @return a new instance of TensorListConcat */ @Endpoint(describeByClass = true) - public static TensorListConcat create(Scope scope, Operand inputHandle, Operand elementShape, Operand leadingDims, DataType elementDtype) { + public static TensorListConcat create(Scope scope, Operand inputHandle, Operand elementShape, Operand leadingDims, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListConcatV2", scope.makeOpName("TensorListConcat")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java index fdca8e2d6cd..285bc6e734a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -43,7 +42,7 @@ public final class TensorListConcatLists extends RawOp implements Operand * @return a new instance of TensorListConcatLists */ @Endpoint(describeByClass = true) - public static TensorListConcatLists create(Scope scope, Operand inputA, Operand inputB, DataType elementDtype) { + public static TensorListConcatLists create(Scope scope, Operand inputA, Operand inputB, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListConcatLists", scope.makeOpName("TensorListConcatLists")); opBuilder.addInput(inputA.asOutput()); opBuilder.addInput(inputB.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java index 2c3e3a5b90c..17dfddcb958 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * The shape of the elements of the given list, as a tensor. @@ -49,7 +47,7 @@ public final class TensorListElementShape extends RawOp imple * @return a new instance of TensorListElementShape */ @Endpoint(describeByClass = true) - public static TensorListElementShape create(Scope scope, Operand inputHandle, DataType shapeType) { + public static TensorListElementShape create(Scope scope, Operand inputHandle, Class shapeType) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListElementShape", scope.makeOpName("TensorListElementShape")); opBuilder.addInput(inputHandle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java index dbb565a29a6..d99e658f936 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -55,7 +54,7 @@ public final class TensorListGather extends RawOp implements Op * @return a new instance of TensorListGather */ @Endpoint(describeByClass = true) - public static TensorListGather create(Scope scope, Operand inputHandle, Operand indices, Operand elementShape, DataType elementDtype) { + public static TensorListGather create(Scope scope, Operand inputHandle, Operand indices, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListGather", scope.makeOpName("TensorListGather")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(indices.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java index 9af925e1425..ed92723ab9f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,7 +45,7 @@ public final class TensorListGetItem extends RawOp implements O * @return a new instance of TensorListGetItem */ @Endpoint(describeByClass = true) - public static TensorListGetItem create(Scope scope, Operand inputHandle, Operand index, Operand elementShape, DataType elementDtype) { + public static TensorListGetItem create(Scope scope, Operand inputHandle, Operand index, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListGetItem", scope.makeOpName("TensorListGetItem")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(index.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java index 96fd3433e07..1cb362c913c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -54,7 +53,7 @@ public final class TensorListPopBack extends RawOp { * @return a new instance of TensorListPopBack */ @Endpoint(describeByClass = true) - public static TensorListPopBack create(Scope scope, Operand inputHandle, Operand elementShape, DataType elementDtype) { + public static TensorListPopBack create(Scope scope, Operand inputHandle, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListPopBack", scope.makeOpName("TensorListPopBack")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java index 3f798e0f998..e7c6abda2c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,7 +50,7 @@ public final class TensorListReserve extends RawOp implements Operand { * @return a new instance of TensorListReserve */ @Endpoint(describeByClass = true) - public static TensorListReserve create(Scope scope, Operand elementShape, Operand numElements, DataType elementDtype) { + public static TensorListReserve create(Scope scope, Operand elementShape, Operand numElements, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListReserve", scope.makeOpName("TensorListReserve")); opBuilder.addInput(elementShape.asOutput()); opBuilder.addInput(numElements.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java index c0ecf388b3d..dfc2357d18c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -74,7 +73,7 @@ private Options() { * @return a new instance of TensorListStack */ @Endpoint(describeByClass = true) - public static TensorListStack create(Scope scope, Operand inputHandle, Operand elementShape, DataType elementDtype, Options... options) { + public static TensorListStack create(Scope scope, Operand inputHandle, Operand elementShape, Class elementDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListStack", scope.makeOpName("TensorListStack")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMax.java deleted file mode 100644 index 1a3042926f6..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMax.java +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; - -/** - * @param data type for {@code output()} output - */ -@Operator -public final class TensorScatterMax extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new TensorScatterMax operation. - * - * @param scope current scope - * @param tensor Tensor to update. - * @param indices Index tensor. - * @param updates Updates to scatter into output. - * @return a new instance of TensorScatterMax - */ - @Endpoint(describeByClass = true) - public static TensorScatterMax create(Scope scope, Operand tensor, Operand indices, Operand updates) { - OperationBuilder opBuilder = scope.env().opBuilder("TensorScatterMax", scope.makeOpName("TensorScatterMax")); - opBuilder.addInput(tensor.asOutput()); - opBuilder.addInput(indices.asOutput()); - opBuilder.addInput(updates.asOutput()); - opBuilder = scope.applyControlDependencies(opBuilder); - return new TensorScatterMax(opBuilder.build()); - } - - /** - * A new tensor copied from tensor whose values are element-wise maximum between tensor and updates according to the indices. - */ - public Output output() { - return output; - } - - @Override - public Output asOutput() { - return output; - } - - private Output output; - - private TensorScatterMax(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMin.java deleted file mode 100644 index 5f7b0d0eb77..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterMin.java +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; - -/** - * @param data type for {@code output()} output - */ -@Operator -public final class TensorScatterMin extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new TensorScatterMin operation. - * - * @param scope current scope - * @param tensor Tensor to update. - * @param indices Index tensor. - * @param updates Updates to scatter into output. - * @return a new instance of TensorScatterMin - */ - @Endpoint(describeByClass = true) - public static TensorScatterMin create(Scope scope, Operand tensor, Operand indices, Operand updates) { - OperationBuilder opBuilder = scope.env().opBuilder("TensorScatterMin", scope.makeOpName("TensorScatterMin")); - opBuilder.addInput(tensor.asOutput()); - opBuilder.addInput(indices.asOutput()); - opBuilder.addInput(updates.asOutput()); - opBuilder = scope.applyControlDependencies(opBuilder); - return new TensorScatterMin(opBuilder.build()); - } - - /** - * A new tensor copied from tensor whose values are element-wise minimum between tensor and updates according to the indices. - */ - public Output output() { - return output; - } - - @Override - public Output asOutput() { - return output; - } - - private Output output; - - private TensorScatterMin(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java index f3b8b04aead..8d1bcc7a6ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -90,7 +89,7 @@ public final class Unique extends RawOp { * @return a new instance of Unique */ @Endpoint(describeByClass = true) - public static Unique create(Scope scope, Operand x, Operand axis, DataType outIdx) { + public static Unique create(Scope scope, Operand x, Operand axis, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueV2", scope.makeOpName("Unique")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(axis.asOutput()); @@ -110,7 +109,7 @@ public static Unique Unique create(Scope scope, Operand x, Operand axis) { - return create(scope, x, axis, TInt32.DTYPE); + return create(scope, x, axis, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java index 732d4432333..6316f3659c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -94,7 +93,7 @@ public final class UniqueWithCounts extends * @return a new instance of UniqueWithCounts */ @Endpoint(describeByClass = true) - public static UniqueWithCounts create(Scope scope, Operand x, Operand axis, DataType outIdx) { + public static UniqueWithCounts create(Scope scope, Operand x, Operand axis, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueWithCountsV2", scope.makeOpName("UniqueWithCounts")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(axis.asOutput()); @@ -114,7 +113,7 @@ public static UniqueWith */ @Endpoint(describeByClass = true) public static UniqueWithCounts create(Scope scope, Operand x, Operand axis) { - return create(scope, x, axis, TInt32.DTYPE); + return create(scope, x, axis, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java index ecc073e37a4..37c5f2ded90 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts an array of flat indices into a tuple of coordinate arrays. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java index 1878ce6fb88..f6bd852f898 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -95,10 +94,10 @@ private Options() { * @return a new instance of Unstage */ @Endpoint(describeByClass = true) - public static Unstage create(Scope scope, List> dtypes, Options... options) { + public static Unstage create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Unstage", scope.makeOpName("Unstage")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java index c46fbcba4de..67d2e9255e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -66,7 +65,7 @@ public final class UpperBound extends RawOp implements Operan * @return a new instance of UpperBound */ @Endpoint(describeByClass = true) - public static UpperBound create(Scope scope, Operand sortedInputs, Operand values, DataType outType) { + public static UpperBound create(Scope scope, Operand sortedInputs, Operand values, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("UpperBound", scope.makeOpName("UpperBound")); opBuilder.addInput(sortedInputs.asOutput()); opBuilder.addInput(values.asOutput()); @@ -86,7 +85,7 @@ public static UpperBound create(Scope sc */ @Endpoint(describeByClass = true) public static UpperBound create(Scope scope, Operand sortedInputs, Operand values) { - return create(scope, sortedInputs, values, TInt32.DTYPE); + return create(scope, sortedInputs, values, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java index 92618c33d37..c264519b454 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -85,7 +84,7 @@ private Options() { * @return a new instance of VarHandleOp */ @Endpoint(describeByClass = true) - public static VarHandleOp create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static VarHandleOp create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VarHandleOp", scope.makeOpName("VarHandleOp")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java index 7353dfee688..bcd9719c932 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -81,7 +80,7 @@ private Options() { * @return a new instance of Variable */ @Endpoint(describeByClass = true) - public static Variable create(Scope scope, Shape shape, DataType dtype, Options... options) { + public static Variable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("shape", shape); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java index 6573cb8de11..2a4c0ba60ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the shape of the variable pointed to by `resource`. @@ -56,7 +54,7 @@ public final class VariableShape extends RawOp implements Ope * @return a new instance of VariableShape */ @Endpoint(describeByClass = true) - public static VariableShape create(Scope scope, Operand input, DataType outType) { + public static VariableShape create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("VariableShape", scope.makeOpName("VariableShape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -73,7 +71,7 @@ public static VariableShape create(Scope scope, Operand create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java index bd3d1a400f5..8236e750ff5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; @@ -27,6 +26,7 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * A container for an iterator resource. @@ -43,10 +43,10 @@ public final class AnonymousIterator extends RawOp { * @return a new instance of AnonymousIterator */ @Endpoint(describeByClass = true) - public static AnonymousIterator create(Scope scope, List> outputTypes, List outputShapes) { + public static AnonymousIterator create(Scope scope, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AnonymousIteratorV2", scope.makeOpName("AnonymousIterator")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java index bff57b33c8f..9b5ba5b7fd9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; @@ -27,6 +26,7 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * A container for a multi device iterator resource. @@ -43,7 +43,7 @@ public final class AnonymousMultiDeviceIterator extends RawOp { * @return a new instance of AnonymousMultiDeviceIterator */ @Endpoint(describeByClass = true) - public static AnonymousMultiDeviceIterator create(Scope scope, List devices, List> outputTypes, List outputShapes) { + public static AnonymousMultiDeviceIterator create(Scope scope, List devices, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AnonymousMultiDeviceIterator", scope.makeOpName("AnonymousMultiDeviceIterator")); opBuilder = scope.applyControlDependencies(opBuilder); String[] devicesArray = new String[devices.size()]; @@ -51,7 +51,7 @@ public static AnonymousMultiDeviceIterator create(Scope scope, List devi devicesArray[i] = devices.get(i); } opBuilder.setAttr("devices", devicesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java index c2825e29b93..0fd73311cf3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -58,12 +57,12 @@ public final class AssertNextDataset extends RawOp implements Operand { * @return a new instance of AssertNextDataset */ @Endpoint(describeByClass = true) - public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { + public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AssertNextDataset", scope.makeOpName("AssertNextDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(transformations.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java index aca90031261..c71a1b96f5e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -76,13 +75,13 @@ private Options() { * @return a new instance of AutoShardDataset */ @Endpoint(describeByClass = true) - public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("AutoShardDataset", scope.makeOpName("AutoShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numWorkers.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java index b0fd6ef0c0c..7409bdade59 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -71,13 +70,13 @@ private Options() { * @return a new instance of BatchDataset */ @Endpoint(describeByClass = true) - public static BatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand dropRemainder, List> outputTypes, List outputShapes, Options... options) { + public static BatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand dropRemainder, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("BatchDatasetV2", scope.makeOpName("BatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(dropRemainder.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java index a3b95d92909..4000a238b86 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class BytesProducedStatsDataset extends RawOp implements Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static BytesProducedStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("BytesProducedStatsDataset", scope.makeOpName("BytesProducedStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java index 07384fb6b3e..1de01e0eccd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,12 +52,12 @@ public final class CacheDataset extends RawOp implements Operand { * @return a new instance of CacheDataset */ @Endpoint(describeByClass = true) - public static CacheDataset create(Scope scope, Operand inputDataset, Operand filename, List> outputTypes, List outputShapes) { + public static CacheDataset create(Scope scope, Operand inputDataset, Operand filename, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("CacheDataset", scope.makeOpName("CacheDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(filename.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java index ba370108e04..3bcab52b253 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,13 +46,13 @@ public final class CacheDatasetV2 extends RawOp implements Operand { * @return a new instance of CacheDatasetV2 */ @Endpoint(describeByClass = true) - public static CacheDatasetV2 create(Scope scope, Operand inputDataset, Operand filename, Operand cache, List> outputTypes, List outputShapes) { + public static CacheDatasetV2 create(Scope scope, Operand inputDataset, Operand filename, Operand cache, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("CacheDatasetV2", scope.makeOpName("CacheDatasetV2")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(filename.asOutput()); opBuilder.addInput(cache.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java index a8d9ce445fe..74f34bb8023 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class ChooseFastestDataset extends RawOp implements Operand * @return a new instance of ChooseFastestDataset */ @Endpoint(describeByClass = true) - public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { + public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ChooseFastestDataset", scope.makeOpName("ChooseFastestDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("num_experiments", numExperiments); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java index 19cf0b4a706..34326835089 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class ConcatenateDataset extends RawOp implements Operand { * @return a new instance of ConcatenateDataset */ @Endpoint(describeByClass = true) - public static ConcatenateDataset create(Scope scope, Operand inputDataset, Operand anotherDataset, List> outputTypes, List outputShapes) { + public static ConcatenateDataset create(Scope scope, Operand inputDataset, Operand anotherDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ConcatenateDataset", scope.makeOpName("ConcatenateDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(anotherDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java index 91fd5032439..64a20123b56 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,11 +46,11 @@ public final class DatasetToSingleElement extends RawOp implements Iterable dataset, List> outputTypes, List outputShapes) { + public static DatasetToSingleElement create(Scope scope, Operand dataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DatasetToSingleElement", scope.makeOpName("DatasetToSingleElement")); opBuilder.addInput(dataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java index 32bd135c325..1882f3e8beb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,13 +50,13 @@ public final class DenseToSparseBatchDataset extends RawOp implements Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { + public static DenseToSparseBatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DenseToSparseBatchDataset", scope.makeOpName("DenseToSparseBatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(rowShape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java index 6883e2218d9..4bf4ecdf141 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,12 +48,12 @@ public final class DirectedInterleaveDataset extends RawOp implements Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { + public static DirectedInterleaveDataset create(Scope scope, Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DirectedInterleaveDataset", scope.makeOpName("DirectedInterleaveDataset")); opBuilder.addInput(selectorInputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(dataInputDatasets)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java index d6c2b576d7e..ea090e46fc6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class FilterByLastComponentDataset extends RawOp implements Operand * @return a new instance of FilterByLastComponentDataset */ @Endpoint(describeByClass = true) - public static FilterByLastComponentDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static FilterByLastComponentDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("FilterByLastComponentDataset", scope.makeOpName("FilterByLastComponentDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java index 7ffca90d150..08aaf2228a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class IgnoreErrorsDataset extends RawOp implements Operand { * @return a new instance of IgnoreErrorsDataset */ @Endpoint(describeByClass = true) - public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IgnoreErrorsDataset", scope.makeOpName("IgnoreErrorsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java index 45030001714..860c089fd40 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class Iterator extends RawOp implements Operand { * @return a new instance of Iterator */ @Endpoint(describeByClass = true) - public static Iterator create(Scope scope, String sharedName, String container, List> outputTypes, List outputShapes) { + public static Iterator create(Scope scope, String sharedName, String container, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorV2", scope.makeOpName("Iterator")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("shared_name", sharedName); opBuilder.setAttr("container", container); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java index 8264419635a..db50d961930 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -64,11 +63,11 @@ private Options() { * @return a new instance of IteratorFromStringHandle */ @Endpoint(describeByClass = true) - public static IteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { + public static IteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorFromStringHandleV2", scope.makeOpName("IteratorFromStringHandle")); opBuilder.addInput(stringHandle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java index b7be406e7bd..2a02a9dc1ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,11 +47,11 @@ public final class IteratorGetNext extends RawOp implements Iterable iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNext create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNext", scope.makeOpName("IteratorGetNext")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java index 682b7ba6f35..73d16c7b56f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,11 +45,11 @@ public final class IteratorGetNextAsOptional extends RawOp implements Operand iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNextAsOptional create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNextAsOptional", scope.makeOpName("IteratorGetNextAsOptional")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java index 3cb4f072307..27a088fdf67 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,11 +52,11 @@ public final class IteratorGetNextSync extends RawOp implements Iterable iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNextSync create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNextSync", scope.makeOpName("IteratorGetNextSync")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java index d0057d60323..042e866f469 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -58,11 +57,11 @@ public final class LMDBDataset extends RawOp implements Operand { * @return a new instance of LMDBDataset */ @Endpoint(describeByClass = true) - public static LMDBDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { + public static LMDBDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("LMDBDataset", scope.makeOpName("LMDBDataset")); opBuilder.addInput(filenames.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java index 731b70867a7..29b9ba70579 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class LatencyStatsDataset extends RawOp implements Operand { * @return a new instance of LatencyStatsDataset */ @Endpoint(describeByClass = true) - public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("LatencyStatsDataset", scope.makeOpName("LatencyStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java index 3fb21bd4a6e..e0c06a9b3b6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear gradients for a LeakyRelu operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java index d7c67924b55..390c6d281ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class MaxIntraOpParallelismDataset extends RawOp implements Operand * @return a new instance of MaxIntraOpParallelismDataset */ @Endpoint(describeByClass = true) - public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { + public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MaxIntraOpParallelismDataset", scope.makeOpName("MaxIntraOpParallelismDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(maxIntraOpParallelism.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java index 53b6f10204d..fa1acc9e5b2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -76,11 +75,11 @@ private Options() { * @return a new instance of ModelDataset */ @Endpoint(describeByClass = true) - public static ModelDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes, Options... options) { + public static ModelDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ModelDataset", scope.makeOpName("ModelDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java index a3dfeeb5665..cb223414fec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class MultiDeviceIterator extends RawOp implements Operand { * @return a new instance of MultiDeviceIterator */ @Endpoint(describeByClass = true) - public static MultiDeviceIterator create(Scope scope, List devices, String sharedName, String container, List> outputTypes, List outputShapes) { + public static MultiDeviceIterator create(Scope scope, List devices, String sharedName, String container, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIterator", scope.makeOpName("MultiDeviceIterator")); opBuilder = scope.applyControlDependencies(opBuilder); String[] devicesArray = new String[devices.size()]; @@ -59,7 +58,7 @@ public static MultiDeviceIterator create(Scope scope, List devices, Stri opBuilder.setAttr("devices", devicesArray); opBuilder.setAttr("shared_name", sharedName); opBuilder.setAttr("container", container); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java index a9b0b8ec07e..af6394bb70d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -65,11 +64,11 @@ private Options() { * @return a new instance of MultiDeviceIteratorFromStringHandle */ @Endpoint(describeByClass = true) - public static MultiDeviceIteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { + public static MultiDeviceIteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIteratorFromStringHandle", scope.makeOpName("MultiDeviceIteratorFromStringHandle")); opBuilder.addInput(stringHandle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java index 4781bc2d56e..0c0e79f7c31 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,13 +50,13 @@ public final class MultiDeviceIteratorGetNextFromShard extends RawOp implements * @return a new instance of MultiDeviceIteratorGetNextFromShard */ @Endpoint(describeByClass = true) - public static MultiDeviceIteratorGetNextFromShard create(Scope scope, Operand multiDeviceIterator, Operand shardNum, Operand incarnationId, List> outputTypes, List outputShapes) { + public static MultiDeviceIteratorGetNextFromShard create(Scope scope, Operand multiDeviceIterator, Operand shardNum, Operand incarnationId, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIteratorGetNextFromShard", scope.makeOpName("MultiDeviceIteratorGetNextFromShard")); opBuilder.addInput(multiDeviceIterator.asOutput()); opBuilder.addInput(shardNum.asOutput()); opBuilder.addInput(incarnationId.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java index 34043a14f9d..52cdcb8ce6f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -44,11 +43,11 @@ public final class NonSerializableDataset extends RawOp implements Operand inputDataset, List> outputTypes, List outputShapes) { + public static NonSerializableDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("NonSerializableDataset", scope.makeOpName("NonSerializableDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java index 73e27b3ffc9..7025879a980 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -69,12 +68,12 @@ private Options() { * @return a new instance of OptimizeDataset */ @Endpoint(describeByClass = true) - public static OptimizeDataset create(Scope scope, Operand inputDataset, Operand optimizations, List> outputTypes, List outputShapes, Options... options) { + public static OptimizeDataset create(Scope scope, Operand inputDataset, Operand optimizations, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OptimizeDataset", scope.makeOpName("OptimizeDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(optimizations.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java index 8d3be67da03..97a6bc8aea9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,11 +47,11 @@ public final class OptionalGetValue extends RawOp implements Iterable optional, List> outputTypes, List outputShapes) { + public static OptionalGetValue create(Scope scope, Operand optional, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("OptionalGetValue", scope.makeOpName("OptionalGetValue")); opBuilder.addInput(optional.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java index 78995ad6d20..0da54c538a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -77,12 +76,12 @@ private Options() { * @return a new instance of PrefetchDataset */ @Endpoint(describeByClass = true) - public static PrefetchDataset create(Scope scope, Operand inputDataset, Operand bufferSize, List> outputTypes, List outputShapes, Options... options) { + public static PrefetchDataset create(Scope scope, Operand inputDataset, Operand bufferSize, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PrefetchDataset", scope.makeOpName("PrefetchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java index b06618ce07a..035a6402a74 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class PrivateThreadPoolDataset extends RawOp implements Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { + public static PrivateThreadPoolDataset create(Scope scope, Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("PrivateThreadPoolDataset", scope.makeOpName("PrivateThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numThreads.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java index 2b944e789dc..577e512c476 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -60,12 +59,12 @@ public final class RandomDataset extends RawOp implements Operand { * @return a new instance of RandomDataset */ @Endpoint(describeByClass = true) - public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RandomDataset", scope.makeOpName("RandomDataset")); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java index bc83590e03e..1387e2582e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,13 +48,13 @@ public final class RangeDataset extends RawOp implements Operand { * @return a new instance of RangeDataset */ @Endpoint(describeByClass = true) - public static RangeDataset create(Scope scope, Operand start, Operand stop, Operand step, List> outputTypes, List outputShapes) { + public static RangeDataset create(Scope scope, Operand start, Operand stop, Operand step, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RangeDataset", scope.makeOpName("RangeDataset")); opBuilder.addInput(start.asOutput()); opBuilder.addInput(stop.asOutput()); opBuilder.addInput(step.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java index b5a21c7b6b8..d29ba3dd541 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -72,12 +71,12 @@ private Options() { * @return a new instance of RebatchDataset */ @Endpoint(describeByClass = true) - public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { + public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RebatchDataset", scope.makeOpName("RebatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numReplicas.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java index b832025853c..b12001c2439 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,12 +48,12 @@ public final class RepeatDataset extends RawOp implements Operand { * @return a new instance of RepeatDataset */ @Endpoint(describeByClass = true) - public static RepeatDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static RepeatDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RepeatDataset", scope.makeOpName("RepeatDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java index 7876fe27587..716c2c03128 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -57,14 +56,14 @@ public final class SamplingDataset extends RawOp implements Operand { * @return a new instance of SamplingDataset */ @Endpoint(describeByClass = true) - public static SamplingDataset create(Scope scope, Operand inputDataset, Operand rate, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static SamplingDataset create(Scope scope, Operand inputDataset, Operand rate, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SamplingDataset", scope.makeOpName("SamplingDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(rate.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java index 5181c0b1f71..a107e2198b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,14 +47,14 @@ public final class SetStatsAggregatorDataset extends RawOp implements Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { + public static SetStatsAggregatorDataset create(Scope scope, Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SetStatsAggregatorDataset", scope.makeOpName("SetStatsAggregatorDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(statsAggregator.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder.addInput(counterPrefix.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java index 62e18f3fadf..98e043f8e5f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -68,13 +67,13 @@ private Options() { * @return a new instance of ShardDataset */ @Endpoint(describeByClass = true) - public static ShardDataset create(Scope scope, Operand inputDataset, Operand numShards, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static ShardDataset create(Scope scope, Operand inputDataset, Operand numShards, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShardDataset", scope.makeOpName("ShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numShards.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java index c5703e8e85c..fa5538deff9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -70,7 +69,7 @@ private Options() { * @return a new instance of ShuffleAndRepeatDataset */ @Endpoint(describeByClass = true) - public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand count, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { + public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand count, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShuffleAndRepeatDatasetV2", scope.makeOpName("ShuffleAndRepeatDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); @@ -79,7 +78,7 @@ public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDatase opBuilder.addInput(count.asOutput()); opBuilder.addInput(seedGenerator.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java index 3dd522e319c..e8682dec91c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -69,7 +68,7 @@ private Options() { * @return a new instance of ShuffleDataset */ @Endpoint(describeByClass = true) - public static ShuffleDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { + public static ShuffleDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShuffleDatasetV3", scope.makeOpName("ShuffleDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); @@ -77,7 +76,7 @@ public static ShuffleDataset create(Scope scope, Operand inputDataset, Operan opBuilder.addInput(seed2.asOutput()); opBuilder.addInput(seedGenerator.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java index 4e409b5a4ea..2d7171854a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,12 +48,12 @@ public final class SkipDataset extends RawOp implements Operand { * @return a new instance of SkipDataset */ @Endpoint(describeByClass = true) - public static SkipDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static SkipDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SkipDataset", scope.makeOpName("SkipDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java index 67605099106..5feb04bb601 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class SleepDataset extends RawOp implements Operand { * @return a new instance of SleepDataset */ @Endpoint(describeByClass = true) - public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { + public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SleepDataset", scope.makeOpName("SleepDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(sleepMicroseconds.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java index 89f440efd77..5f4fc56c523 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -52,14 +51,14 @@ public final class SlidingWindowDataset extends RawOp implements Operand * @return a new instance of SlidingWindowDataset */ @Endpoint(describeByClass = true) - public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { + public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SlidingWindowDataset", scope.makeOpName("SlidingWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(windowSize.asOutput()); opBuilder.addInput(windowShift.asOutput()); opBuilder.addInput(windowStride.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java deleted file mode 100644 index f6bc66e8297..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java +++ /dev/null @@ -1,376 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.DataType; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * Creates a dataset that will write to / read from a snapshot. - *

- * This dataset attempts to determine whether a valid snapshot exists at the - * `snapshot_path`, and reads from the snapshot in lieu of using `input_dataset`. - * If not, it will run the preprocessing pipeline as usual, and write out a - * snapshot of the data processed for future use. - */ -public final class SnapshotDataset extends RawOp implements Operand { - - /** - * Optional attributes for {@link org.tensorflow.op.data.SnapshotDataset} - */ - public static class Options { - - /** - * @param compression - */ - public Options compression(String compression) { - this.compression = compression; - return this; - } - - /** - * @param readerPathPrefix - */ - public Options readerPathPrefix(String readerPathPrefix) { - this.readerPathPrefix = readerPathPrefix; - return this; - } - - /** - * @param writerPathPrefix - */ - public Options writerPathPrefix(String writerPathPrefix) { - this.writerPathPrefix = writerPathPrefix; - return this; - } - - /** - * @param shardSizeBytes - */ - public Options shardSizeBytes(Long shardSizeBytes) { - this.shardSizeBytes = shardSizeBytes; - return this; - } - - /** - * @param pendingSnapshotExpirySeconds - */ - public Options pendingSnapshotExpirySeconds(Long pendingSnapshotExpirySeconds) { - this.pendingSnapshotExpirySeconds = pendingSnapshotExpirySeconds; - return this; - } - - /** - * @param numReaderThreads - */ - public Options numReaderThreads(Long numReaderThreads) { - this.numReaderThreads = numReaderThreads; - return this; - } - - /** - * @param readerBufferSize - */ - public Options readerBufferSize(Long readerBufferSize) { - this.readerBufferSize = readerBufferSize; - return this; - } - - /** - * @param numWriterThreads - */ - public Options numWriterThreads(Long numWriterThreads) { - this.numWriterThreads = numWriterThreads; - return this; - } - - /** - * @param writerBufferSize - */ - public Options writerBufferSize(Long writerBufferSize) { - this.writerBufferSize = writerBufferSize; - return this; - } - - /** - * @param shuffleOnRead - */ - public Options shuffleOnRead(Boolean shuffleOnRead) { - this.shuffleOnRead = shuffleOnRead; - return this; - } - - /** - * @param seed - */ - public Options seed(Long seed) { - this.seed = seed; - return this; - } - - /** - * @param seed2 - */ - public Options seed2(Long seed2) { - this.seed2 = seed2; - return this; - } - - /** - * @param mode - */ - public Options mode(String mode) { - this.mode = mode; - return this; - } - - /** - * @param snapshotName - */ - public Options snapshotName(String snapshotName) { - this.snapshotName = snapshotName; - return this; - } - - private String compression; - private String readerPathPrefix; - private String writerPathPrefix; - private Long shardSizeBytes; - private Long pendingSnapshotExpirySeconds; - private Long numReaderThreads; - private Long readerBufferSize; - private Long numWriterThreads; - private Long writerBufferSize; - private Boolean shuffleOnRead; - private Long seed; - private Long seed2; - private String mode; - private String snapshotName; - - private Options() { - } - } - - /** - * Factory method to create a class wrapping a new SnapshotDataset operation. - * - * @param scope current scope - * @param inputDataset A variant tensor representing the input dataset. - * @param path The path we should write snapshots to / read snapshots from. - * @param outputTypes - * @param outputShapes - * @param options carries optional attributes values - * @return a new instance of SnapshotDataset - */ - @Endpoint(describeByClass = true) - public static SnapshotDataset create(Scope scope, Operand inputDataset, Operand path, List> outputTypes, List outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("SnapshotDataset", scope.makeOpName("SnapshotDataset")); - opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(path.asOutput()); - opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0; i < outputShapesArray.length; ++i) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.compression != null) { - opBuilder.setAttr("compression", opts.compression); - } - if (opts.readerPathPrefix != null) { - opBuilder.setAttr("reader_path_prefix", opts.readerPathPrefix); - } - if (opts.writerPathPrefix != null) { - opBuilder.setAttr("writer_path_prefix", opts.writerPathPrefix); - } - if (opts.shardSizeBytes != null) { - opBuilder.setAttr("shard_size_bytes", opts.shardSizeBytes); - } - if (opts.pendingSnapshotExpirySeconds != null) { - opBuilder.setAttr("pending_snapshot_expiry_seconds", opts.pendingSnapshotExpirySeconds); - } - if (opts.numReaderThreads != null) { - opBuilder.setAttr("num_reader_threads", opts.numReaderThreads); - } - if (opts.readerBufferSize != null) { - opBuilder.setAttr("reader_buffer_size", opts.readerBufferSize); - } - if (opts.numWriterThreads != null) { - opBuilder.setAttr("num_writer_threads", opts.numWriterThreads); - } - if (opts.writerBufferSize != null) { - opBuilder.setAttr("writer_buffer_size", opts.writerBufferSize); - } - if (opts.shuffleOnRead != null) { - opBuilder.setAttr("shuffle_on_read", opts.shuffleOnRead); - } - if (opts.seed != null) { - opBuilder.setAttr("seed", opts.seed); - } - if (opts.seed2 != null) { - opBuilder.setAttr("seed2", opts.seed2); - } - if (opts.mode != null) { - opBuilder.setAttr("mode", opts.mode); - } - if (opts.snapshotName != null) { - opBuilder.setAttr("snapshot_name", opts.snapshotName); - } - } - } - return new SnapshotDataset(opBuilder.build()); - } - - /** - * @param compression - */ - public static Options compression(String compression) { - return new Options().compression(compression); - } - - /** - * @param readerPathPrefix - */ - public static Options readerPathPrefix(String readerPathPrefix) { - return new Options().readerPathPrefix(readerPathPrefix); - } - - /** - * @param writerPathPrefix - */ - public static Options writerPathPrefix(String writerPathPrefix) { - return new Options().writerPathPrefix(writerPathPrefix); - } - - /** - * @param shardSizeBytes - */ - public static Options shardSizeBytes(Long shardSizeBytes) { - return new Options().shardSizeBytes(shardSizeBytes); - } - - /** - * @param pendingSnapshotExpirySeconds - */ - public static Options pendingSnapshotExpirySeconds(Long pendingSnapshotExpirySeconds) { - return new Options().pendingSnapshotExpirySeconds(pendingSnapshotExpirySeconds); - } - - /** - * @param numReaderThreads - */ - public static Options numReaderThreads(Long numReaderThreads) { - return new Options().numReaderThreads(numReaderThreads); - } - - /** - * @param readerBufferSize - */ - public static Options readerBufferSize(Long readerBufferSize) { - return new Options().readerBufferSize(readerBufferSize); - } - - /** - * @param numWriterThreads - */ - public static Options numWriterThreads(Long numWriterThreads) { - return new Options().numWriterThreads(numWriterThreads); - } - - /** - * @param writerBufferSize - */ - public static Options writerBufferSize(Long writerBufferSize) { - return new Options().writerBufferSize(writerBufferSize); - } - - /** - * @param shuffleOnRead - */ - public static Options shuffleOnRead(Boolean shuffleOnRead) { - return new Options().shuffleOnRead(shuffleOnRead); - } - - /** - * @param seed - */ - public static Options seed(Long seed) { - return new Options().seed(seed); - } - - /** - * @param seed2 - */ - public static Options seed2(Long seed2) { - return new Options().seed2(seed2); - } - - /** - * @param mode - */ - public static Options mode(String mode) { - return new Options().mode(mode); - } - - /** - * @param snapshotName - */ - public static Options snapshotName(String snapshotName) { - return new Options().snapshotName(snapshotName); - } - - /** - */ - public Output handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output asOutput() { - return (Output) handle; - } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SnapshotDataset"; - - private Output handle; - - private SnapshotDataset(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java index 59a0a47cd4d..cb7140de450 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,13 +47,13 @@ public final class SqlDataset extends RawOp implements Operand { * @return a new instance of SqlDataset */ @Endpoint(describeByClass = true) - public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { + public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SqlDataset", scope.makeOpName("SqlDataset")); opBuilder.addInput(driverName.asOutput()); opBuilder.addInput(dataSourceName.asOutput()); opBuilder.addInput(query.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java index 3229f055924..8c12d56db8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -50,12 +49,12 @@ public final class TakeDataset extends RawOp implements Operand { * @return a new instance of TakeDataset */ @Endpoint(describeByClass = true) - public static TakeDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static TakeDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("TakeDataset", scope.makeOpName("TakeDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java index 0b1ab263ffa..9cc2ad12bba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class ThreadPoolDataset extends RawOp implements Operand { * @return a new instance of ThreadPoolDataset */ @Endpoint(describeByClass = true) - public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { + public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ThreadPoolDataset", scope.makeOpName("ThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(threadPool.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java index 1bca9b693cc..0ead7ff2ef1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class UnbatchDataset extends RawOp implements Operand { * @return a new instance of UnbatchDataset */ @Endpoint(describeByClass = true) - public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UnbatchDataset", scope.makeOpName("UnbatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java index 817519ef6f8..f47548f8dad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class UniqueDataset extends RawOp implements Operand { * @return a new instance of UniqueDataset */ @Endpoint(describeByClass = true) - public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueDataset", scope.makeOpName("UniqueDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java index d5edcc917d8..332338991f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -98,7 +97,7 @@ public final class WindowDataset extends RawOp implements Operand { * @return a new instance of WindowDataset */ @Endpoint(describeByClass = true) - public static WindowDataset create(Scope scope, Operand inputDataset, Operand size, Operand shift, Operand stride, Operand dropRemainder, List> outputTypes, List outputShapes) { + public static WindowDataset create(Scope scope, Operand inputDataset, Operand size, Operand shift, Operand stride, Operand dropRemainder, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("WindowDataset", scope.makeOpName("WindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(size.asOutput()); @@ -106,7 +105,7 @@ public static WindowDataset create(Scope scope, Operand inputDataset, Operand opBuilder.addInput(stride.asOutput()); opBuilder.addInput(dropRemainder.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java index d4f54946f5f..d5414b73ecf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,11 +52,11 @@ public final class ZipDataset extends RawOp implements Operand { * @return a new instance of ZipDataset */ @Endpoint(describeByClass = true) - public static ZipDataset create(Scope scope, Iterable> inputDatasets, List> outputTypes, List outputShapes) { + public static ZipDataset create(Scope scope, Iterable> inputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ZipDataset", scope.makeOpName("ZipDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java index 96c02bf5fc1..6e3694983d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class AssertCardinalityDataset extends RawOp implements Operand inputDataset, Operand cardinality, List> outputTypes, List outputShapes) { + public static AssertCardinalityDataset create(Scope scope, Operand inputDataset, Operand cardinality, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AssertCardinalityDataset", scope.makeOpName("AssertCardinalityDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(cardinality.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java index cd0c5300df6..f4f0b763dc5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class AssertNextDataset extends RawOp implements Operand { * @return a new instance of AssertNextDataset */ @Endpoint(describeByClass = true) - public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { + public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalAssertNextDataset", scope.makeOpName("AssertNextDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(transformations.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java index 3c1eb053091..d818cdd4175 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -76,13 +75,13 @@ private Options() { * @return a new instance of AutoShardDataset */ @Endpoint(describeByClass = true) - public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalAutoShardDataset", scope.makeOpName("AutoShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numWorkers.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java index 7fe1290b130..dcf5179b0c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class BytesProducedStatsDataset extends RawOp implements Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static BytesProducedStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalBytesProducedStatsDataset", scope.makeOpName("BytesProducedStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java index f731d948490..aee0c946f14 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class ChooseFastestDataset extends RawOp implements Operand * @return a new instance of ChooseFastestDataset */ @Endpoint(describeByClass = true) - public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { + public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalChooseFastestDataset", scope.makeOpName("ChooseFastestDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("num_experiments", numExperiments); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java index b3e853f9de4..3330b2163e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -73,7 +72,7 @@ private Options() { * @return a new instance of DataServiceDataset */ @Endpoint(describeByClass = true) - public static DataServiceDataset create(Scope scope, Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, List> outputTypes, List outputShapes, Options... options) { + public static DataServiceDataset create(Scope scope, Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DataServiceDataset", scope.makeOpName("DataServiceDataset")); opBuilder.addInput(datasetId.asOutput()); opBuilder.addInput(processingMode.asOutput()); @@ -83,7 +82,7 @@ public static DataServiceDataset create(Scope scope, Operand datasetId, opBuilder.addInput(maxOutstandingRequests.asOutput()); opBuilder.addInput(iterationCounter.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java index d464f422836..ff2910a2e38 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,13 +50,13 @@ public final class DenseToSparseBatchDataset extends RawOp implements Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { + public static DenseToSparseBatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalDenseToSparseBatchDataset", scope.makeOpName("DenseToSparseBatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(rowShape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java index 63a06a16201..50ee8a37fcb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,12 +48,12 @@ public final class DirectedInterleaveDataset extends RawOp implements Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { + public static DirectedInterleaveDataset create(Scope scope, Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalDirectedInterleaveDataset", scope.makeOpName("DirectedInterleaveDataset")); opBuilder.addInput(selectorInputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(dataInputDatasets)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java index 22d734831da..524814790f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class IgnoreErrorsDataset extends RawOp implements Operand { * @return a new instance of IgnoreErrorsDataset */ @Endpoint(describeByClass = true) - public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalIgnoreErrorsDataset", scope.makeOpName("IgnoreErrorsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java index ad901295c28..cd9f0a2c498 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class LatencyStatsDataset extends RawOp implements Operand { * @return a new instance of LatencyStatsDataset */ @Endpoint(describeByClass = true) - public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalLatencyStatsDataset", scope.makeOpName("LatencyStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java index e6c489833af..3ec97c34b61 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class LmdbDataset extends RawOp implements Operand { * @return a new instance of LmdbDataset */ @Endpoint(describeByClass = true) - public static LmdbDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { + public static LmdbDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalLMDBDataset", scope.makeOpName("LmdbDataset")); opBuilder.addInput(filenames.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java index fb69729fab0..88d924e5e43 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class MaxIntraOpParallelismDataset extends RawOp implements Operand * @return a new instance of MaxIntraOpParallelismDataset */ @Endpoint(describeByClass = true) - public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { + public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMaxIntraOpParallelismDataset", scope.makeOpName("MaxIntraOpParallelismDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(maxIntraOpParallelism.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java index 9f07f37f804..45866ff08b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -44,11 +43,11 @@ public final class NonSerializableDataset extends RawOp implements Operand inputDataset, List> outputTypes, List outputShapes) { + public static NonSerializableDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalNonSerializableDataset", scope.makeOpName("NonSerializableDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java index 4c24cf99558..c4602914a42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -102,7 +101,7 @@ private Options() { * @return a new instance of ParseExampleDataset */ @Endpoint(describeByClass = true) - public static ParseExampleDataset create(Scope scope, Operand inputDataset, Operand numParallelCalls, Iterable> denseDefaults, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes, List> outputTypes, List outputShapes, List> raggedValueTypes, List> raggedSplitTypes, Options... options) { + public static ParseExampleDataset create(Scope scope, Operand inputDataset, Operand numParallelCalls, Iterable> denseDefaults, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes, List> outputTypes, List outputShapes, List> raggedValueTypes, List> raggedSplitTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseExampleDatasetV2", scope.makeOpName("ParseExampleDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numParallelCalls.asOutput()); @@ -118,7 +117,7 @@ public static ParseExampleDataset create(Scope scope, Operand inputDataset, O denseKeysArray[i] = denseKeys.get(i); } opBuilder.setAttr("dense_keys", denseKeysArray); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; + Class[] sparseTypesArray = new Class[sparseTypes.size()]; for (int i = 0; i < sparseTypesArray.length; ++i) { sparseTypesArray[i] = sparseTypes.get(i); } @@ -128,7 +127,7 @@ public static ParseExampleDataset create(Scope scope, Operand inputDataset, O denseShapesArray[i] = denseShapes.get(i); } opBuilder.setAttr("dense_shapes", denseShapesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } @@ -138,12 +137,12 @@ public static ParseExampleDataset create(Scope scope, Operand inputDataset, O outputShapesArray[i] = outputShapes.get(i); } opBuilder.setAttr("output_shapes", outputShapesArray); - DataType[] raggedValueTypesArray = new DataType[raggedValueTypes.size()]; + Class[] raggedValueTypesArray = new Class[raggedValueTypes.size()]; for (int i = 0; i < raggedValueTypesArray.length; ++i) { raggedValueTypesArray[i] = raggedValueTypes.get(i); } opBuilder.setAttr("ragged_value_types", raggedValueTypesArray); - DataType[] raggedSplitTypesArray = new DataType[raggedSplitTypes.size()]; + Class[] raggedSplitTypesArray = new Class[raggedSplitTypes.size()]; for (int i = 0; i < raggedSplitTypesArray.length; ++i) { raggedSplitTypesArray[i] = raggedSplitTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java index a52fd41140e..7d85832bc3e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,12 +46,12 @@ public final class PrivateThreadPoolDataset extends RawOp implements Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { + public static PrivateThreadPoolDataset create(Scope scope, Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalPrivateThreadPoolDataset", scope.makeOpName("PrivateThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numThreads.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java index 8f9109ab3d5..f2e06ee1b97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,12 +48,12 @@ public final class RandomDataset extends RawOp implements Operand { * @return a new instance of RandomDataset */ @Endpoint(describeByClass = true) - public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalRandomDataset", scope.makeOpName("RandomDataset")); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java index 6111166128e..c647ea32f5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -72,12 +71,12 @@ private Options() { * @return a new instance of RebatchDataset */ @Endpoint(describeByClass = true) - public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { + public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalRebatchDataset", scope.makeOpName("RebatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numReplicas.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java index e2535dcecd7..842c2fae29b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,14 +47,14 @@ public final class SetStatsAggregatorDataset extends RawOp implements Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { + public static SetStatsAggregatorDataset create(Scope scope, Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSetStatsAggregatorDataset", scope.makeOpName("SetStatsAggregatorDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(statsAggregator.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder.addInput(counterPrefix.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java index 97ddeb96999..cfe2babce9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class SleepDataset extends RawOp implements Operand { * @return a new instance of SleepDataset */ @Endpoint(describeByClass = true) - public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { + public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSleepDataset", scope.makeOpName("SleepDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(sleepMicroseconds.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java index b8dbded85e9..a7183757b4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -52,14 +51,14 @@ public final class SlidingWindowDataset extends RawOp implements Operand * @return a new instance of SlidingWindowDataset */ @Endpoint(describeByClass = true) - public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { + public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSlidingWindowDataset", scope.makeOpName("SlidingWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(windowSize.asOutput()); opBuilder.addInput(windowShift.asOutput()); opBuilder.addInput(windowStride.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java index a68f46a5c08..c6bd7b4a433 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,13 +47,13 @@ public final class SqlDataset extends RawOp implements Operand { * @return a new instance of SqlDataset */ @Endpoint(describeByClass = true) - public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { + public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSqlDataset", scope.makeOpName("SqlDataset")); opBuilder.addInput(driverName.asOutput()); opBuilder.addInput(dataSourceName.asOutput()); opBuilder.addInput(query.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java index e3fdc2b92cf..2d9a37dba6e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,12 +45,12 @@ public final class ThreadPoolDataset extends RawOp implements Operand { * @return a new instance of ThreadPoolDataset */ @Endpoint(describeByClass = true) - public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { + public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalThreadPoolDataset", scope.makeOpName("ThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(threadPool.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java index dfa611faa0e..f44f8f09fc2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class UnbatchDataset extends RawOp implements Operand { * @return a new instance of UnbatchDataset */ @Endpoint(describeByClass = true) - public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalUnbatchDataset", scope.makeOpName("UnbatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java index c5732154e94..79d4624779d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,11 +46,11 @@ public final class UncompressElement extends RawOp implements Iterable compressed, List> outputTypes, List outputShapes) { + public static UncompressElement create(Scope scope, Operand compressed, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UncompressElement", scope.makeOpName("UncompressElement")); opBuilder.addInput(compressed.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java index eb213348e72..b05203c4542 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,11 +44,11 @@ public final class UniqueDataset extends RawOp implements Operand { * @return a new instance of UniqueDataset */ @Endpoint(describeByClass = true) - public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalUniqueDataset", scope.makeOpName("UniqueDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; + Class[] outputTypesArray = new Class[outputTypes.size()]; for (int i = 0; i < outputTypesArray.length; ++i) { outputTypesArray[i] = outputTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java index e26417c561d..44ec8ca8aaf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Checks a tensor for NaN, -Inf and +Inf values. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java index 85c20b1eef0..3b927a32cf6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java @@ -17,7 +17,6 @@ package org.tensorflow.op.debugging; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -132,7 +131,7 @@ private Options() { * @return a new instance of DebugNumericsSummary */ @Endpoint(describeByClass = true) - public static DebugNumericsSummary create(Scope scope, Operand input, DataType outputDtype, Options... options) { + public static DebugNumericsSummary create(Scope scope, Operand input, Class outputDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DebugNumericSummaryV2", scope.makeOpName("DebugNumericsSummary")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -160,7 +159,7 @@ public static DebugNumericsSummary creat */ @Endpoint(describeByClass = true) public static DebugNumericsSummary create(Scope scope, Operand input, Options... options) { - return create(scope, input, TFloat32.DTYPE, options); + return create(scope, input, TFloat32.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java index 49699f07c8c..66c97240952 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java @@ -17,7 +17,6 @@ package org.tensorflow.op.dtypes; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -65,7 +64,7 @@ private Options() { * @return a new instance of Cast */ @Endpoint(describeByClass = true) - public static Cast create(Scope scope, Operand x, DataType DstT, Options... options) { + public static Cast create(Scope scope, Operand x, Class DstT, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Cast", scope.makeOpName("Cast")); opBuilder.addInput(x.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java index 5e8918785a0..4d245490516 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java @@ -17,7 +17,6 @@ package org.tensorflow.op.dtypes; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -62,7 +61,7 @@ public final class Complex extends RawOp implements Operand * @return a new instance of Complex */ @Endpoint(describeByClass = true) - public static Complex create(Scope scope, Operand real, Operand imag, DataType Tout) { + public static Complex create(Scope scope, Operand real, Operand imag, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Complex", scope.makeOpName("Complex")); opBuilder.addInput(real.asOutput()); opBuilder.addInput(imag.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java index ab5b3f32e61..b33b91ed953 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the contrast of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java index 2fe15aa50f1..7f1c320f052 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the hue of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java index 03270949652..49a7f6ef6ae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the saturation of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java index f0048c1014b..58e9ed905d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extracts crops from the input image tensor and resizes them. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java index 67263c1e576..4c3c0f678c1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of the crop_and_resize op wrt the input boxes tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java index 6a6415f879d..32a7bc6d247 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java @@ -17,7 +17,6 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of the crop_and_resize op wrt the input image tensor. @@ -84,7 +82,7 @@ private Options() { * @return a new instance of CropAndResizeGradImage */ @Endpoint(describeByClass = true) - public static CropAndResizeGradImage create(Scope scope, Operand grads, Operand boxes, Operand boxInd, Operand imageSize, DataType T, Options... options) { + public static CropAndResizeGradImage create(Scope scope, Operand grads, Operand boxes, Operand boxInd, Operand imageSize, Class T, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CropAndResizeGradImage", scope.makeOpName("CropAndResizeGradImage")); opBuilder.addInput(grads.asOutput()); opBuilder.addInput(boxes.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java index 1d97bc54df7..dc917cb869a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java @@ -17,7 +17,6 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decode a PNG-encoded image to a uint8 or uint16 tensor. @@ -92,7 +90,7 @@ private Options() { * @return a new instance of DecodePng */ @Endpoint(describeByClass = true) - public static DecodePng create(Scope scope, Operand contents, DataType dtype, Options... options) { + public static DecodePng create(Scope scope, Operand contents, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodePng", scope.makeOpName("DecodePng")); opBuilder.addInput(contents.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -117,7 +115,7 @@ public static DecodePng create(Scope scope, Operand create(Scope scope, Operand contents, Options... options) { - return create(scope, contents, TUint8.DTYPE, options); + return create(scope, contents, TUint8.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java index ba674376d07..edd3f0654d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draw bounding boxes on a batch of images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java index cf5b5a0ebc5..5944fbc9f87 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * PNG-encode an image. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java index dfa8c4ea7b3..1f5ff7163fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java @@ -17,7 +17,6 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extract the shape information of a JPEG-encoded image. @@ -51,7 +49,7 @@ public final class ExtractJpegShape extends RawOp implements * @return a new instance of ExtractJpegShape */ @Endpoint(describeByClass = true) - public static ExtractJpegShape create(Scope scope, Operand contents, DataType outputType) { + public static ExtractJpegShape create(Scope scope, Operand contents, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ExtractJpegShape", scope.makeOpName("ExtractJpegShape")); opBuilder.addInput(contents.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -68,7 +66,7 @@ public static ExtractJpegShape create(Scope scope, Operan */ @Endpoint(describeByClass = true) public static ExtractJpegShape create(Scope scope, Operand contents) { - return create(scope, contents, TInt32.DTYPE); + return create(scope, contents, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java index fb752a83518..90b9c5ea776 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Convert one or more images from HSV to RGB. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java index 3363d6a9804..a1c786cb619 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Applies the given transform to each of the images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java index 0745299e34f..919a761f5c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Greedily selects a subset of bounding boxes in descending order of score, diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java index 47d03d63187..a5e2de121b1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Randomly crop `image`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java index 60d1e473ef0..895bdfac587 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using area interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java index 2519eaa6ea9..7b3b9ffdcdd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using bicubic interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java index 96296cdf33f..e57cd749a72 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of bicubic interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java index d0c28fed013..8b3bba2c9ae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using bilinear interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java index b8aef51ba30..572f8693d4b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of bilinear interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java index dd7865841de..1da719fb569 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using nearest neighbor interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java index 4897c0f9dc6..f0139dcca6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of nearest neighbor interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java index acee7a6061e..0ac21df35a5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts one or more images from RGB to HSV. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java index 7f83b2ee84c..583da52ce0e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generate a single randomly distorted bounding box for an image. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java index 26bf544be2e..cf52ddc87d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java index 4b8d9b0fc4b..79d6c35a46e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java index 820b0077ba5..87a4ed8300c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reinterpret the bytes of a string as a vector of numbers. @@ -71,7 +69,7 @@ private Options() { * @return a new instance of DecodePaddedRaw */ @Endpoint(describeByClass = true) - public static DecodePaddedRaw create(Scope scope, Operand inputBytes, Operand fixedLength, DataType outType, Options... options) { + public static DecodePaddedRaw create(Scope scope, Operand inputBytes, Operand fixedLength, Class outType, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodePaddedRaw", scope.makeOpName("DecodePaddedRaw")); opBuilder.addInput(inputBytes.asOutput()); opBuilder.addInput(fixedLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java index dfdabfdb466..24b013325f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -68,7 +67,7 @@ private Options() { * @return a new instance of DecodeRaw */ @Endpoint(describeByClass = true) - public static DecodeRaw create(Scope scope, Operand bytes, DataType outType, Options... options) { + public static DecodeRaw create(Scope scope, Operand bytes, Class outType, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodeRaw", scope.makeOpName("DecodeRaw")); opBuilder.addInput(bytes.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java index 8582681c073..2a00c03b6a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -90,7 +89,7 @@ public final class DeserializeManySparse extends RawOp { * @return a new instance of DeserializeManySparse */ @Endpoint(describeByClass = true) - public static DeserializeManySparse create(Scope scope, Operand serializedSparse, DataType dtype) { + public static DeserializeManySparse create(Scope scope, Operand serializedSparse, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("DeserializeManySparse", scope.makeOpName("DeserializeManySparse")); opBuilder.addInput(serializedSparse.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java index 842985b49e9..0e507d45e89 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java @@ -18,7 +18,6 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -97,10 +96,10 @@ private Options() { * @return a new instance of FifoQueue */ @Endpoint(describeByClass = true) - public static FifoQueue create(Scope scope, List> componentTypes, Options... options) { + public static FifoQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("FIFOQueueV2", scope.makeOpName("FifoQueue")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java index b72d8188038..9efc68bbdc7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java @@ -18,7 +18,6 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -105,10 +104,10 @@ private Options() { * @return a new instance of PaddingFifoQueue */ @Endpoint(describeByClass = true) - public static PaddingFifoQueue create(Scope scope, List> componentTypes, Options... options) { + public static PaddingFifoQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PaddingFIFOQueueV2", scope.makeOpName("PaddingFifoQueue")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java index bd8f0f4a603..cbcf7e4deef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -33,6 +32,7 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; /** * Transforms a vector of tf.Example protos (as strings) into typed tensors. @@ -97,7 +97,7 @@ public final class ParseExample extends RawOp { * @return a new instance of ParseExample */ @Endpoint(describeByClass = true) - public static ParseExample create(Scope scope, Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, Iterable> denseDefaults, Long numSparse, List> sparseTypes, List> raggedValueTypes, List> raggedSplitTypes, List denseShapes) { + public static ParseExample create(Scope scope, Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, Iterable> denseDefaults, Long numSparse, List> sparseTypes, List> raggedValueTypes, List> raggedSplitTypes, List denseShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ParseExampleV2", scope.makeOpName("ParseExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(names.asOutput()); @@ -107,17 +107,17 @@ public static ParseExample create(Scope scope, Operand serialized, Oper opBuilder.addInputList(Operands.asOutputs(denseDefaults)); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("num_sparse", numSparse); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; + Class[] sparseTypesArray = new Class[sparseTypes.size()]; for (int i = 0; i < sparseTypesArray.length; ++i) { sparseTypesArray[i] = sparseTypes.get(i); } opBuilder.setAttr("sparse_types", sparseTypesArray); - DataType[] raggedValueTypesArray = new DataType[raggedValueTypes.size()]; + Class[] raggedValueTypesArray = new Class[raggedValueTypes.size()]; for (int i = 0; i < raggedValueTypesArray.length; ++i) { raggedValueTypesArray[i] = raggedValueTypes.get(i); } opBuilder.setAttr("ragged_value_types", raggedValueTypesArray); - DataType[] raggedSplitTypesArray = new DataType[raggedSplitTypes.size()]; + Class[] raggedSplitTypesArray = new Class[raggedSplitTypes.size()]; for (int i = 0; i < raggedSplitTypesArray.length; ++i) { raggedSplitTypesArray[i] = raggedSplitTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java index 1c10b138c55..df0db9fabc0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -34,6 +33,7 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; /** * Transforms a vector of tf.io.SequenceExample protos (as strings) into @@ -153,7 +153,7 @@ private Options() { * @return a new instance of ParseSequenceExample */ @Endpoint(describeByClass = true) - public static ParseSequenceExample create(Scope scope, Operand serialized, Operand debugName, Operand contextSparseKeys, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, Iterable> contextDenseDefaults, List> contextSparseTypes, List> contextRaggedValueTypes, List> contextRaggedSplitTypes, List> featureListDenseTypes, List> featureListSparseTypes, List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, Options... options) { + public static ParseSequenceExample create(Scope scope, Operand serialized, Operand debugName, Operand contextSparseKeys, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, Iterable> contextDenseDefaults, List> contextSparseTypes, List> contextRaggedValueTypes, List> contextRaggedSplitTypes, List> featureListDenseTypes, List> featureListSparseTypes, List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSequenceExampleV2", scope.makeOpName("ParseSequenceExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(debugName.asOutput()); @@ -166,37 +166,37 @@ public static ParseSequenceExample create(Scope scope, Operand serializ opBuilder.addInput(featureListDenseMissingAssumedEmpty.asOutput()); opBuilder.addInputList(Operands.asOutputs(contextDenseDefaults)); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] contextSparseTypesArray = new DataType[contextSparseTypes.size()]; + Class[] contextSparseTypesArray = new Class[contextSparseTypes.size()]; for (int i = 0; i < contextSparseTypesArray.length; ++i) { contextSparseTypesArray[i] = contextSparseTypes.get(i); } opBuilder.setAttr("context_sparse_types", contextSparseTypesArray); - DataType[] contextRaggedValueTypesArray = new DataType[contextRaggedValueTypes.size()]; + Class[] contextRaggedValueTypesArray = new Class[contextRaggedValueTypes.size()]; for (int i = 0; i < contextRaggedValueTypesArray.length; ++i) { contextRaggedValueTypesArray[i] = contextRaggedValueTypes.get(i); } opBuilder.setAttr("context_ragged_value_types", contextRaggedValueTypesArray); - DataType[] contextRaggedSplitTypesArray = new DataType[contextRaggedSplitTypes.size()]; + Class[] contextRaggedSplitTypesArray = new Class[contextRaggedSplitTypes.size()]; for (int i = 0; i < contextRaggedSplitTypesArray.length; ++i) { contextRaggedSplitTypesArray[i] = contextRaggedSplitTypes.get(i); } opBuilder.setAttr("context_ragged_split_types", contextRaggedSplitTypesArray); - DataType[] featureListDenseTypesArray = new DataType[featureListDenseTypes.size()]; + Class[] featureListDenseTypesArray = new Class[featureListDenseTypes.size()]; for (int i = 0; i < featureListDenseTypesArray.length; ++i) { featureListDenseTypesArray[i] = featureListDenseTypes.get(i); } opBuilder.setAttr("feature_list_dense_types", featureListDenseTypesArray); - DataType[] featureListSparseTypesArray = new DataType[featureListSparseTypes.size()]; + Class[] featureListSparseTypesArray = new Class[featureListSparseTypes.size()]; for (int i = 0; i < featureListSparseTypesArray.length; ++i) { featureListSparseTypesArray[i] = featureListSparseTypes.get(i); } opBuilder.setAttr("feature_list_sparse_types", featureListSparseTypesArray); - DataType[] featureListRaggedValueTypesArray = new DataType[featureListRaggedValueTypes.size()]; + Class[] featureListRaggedValueTypesArray = new Class[featureListRaggedValueTypes.size()]; for (int i = 0; i < featureListRaggedValueTypesArray.length; ++i) { featureListRaggedValueTypesArray[i] = featureListRaggedValueTypes.get(i); } opBuilder.setAttr("feature_list_ragged_value_types", featureListRaggedValueTypesArray); - DataType[] featureListRaggedSplitTypesArray = new DataType[featureListRaggedSplitTypes.size()]; + Class[] featureListRaggedSplitTypesArray = new Class[featureListRaggedSplitTypes.size()]; for (int i = 0; i < featureListRaggedSplitTypesArray.length; ++i) { featureListRaggedSplitTypesArray[i] = featureListRaggedSplitTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java index 2c09ba8ac4c..597e75efa9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -32,6 +31,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Transforms a tf.Example proto (as a string) into typed tensors. @@ -76,7 +76,7 @@ public final class ParseSingleExample extends RawOp { * @return a new instance of ParseSingleExample */ @Endpoint(describeByClass = true) - public static ParseSingleExample create(Scope scope, Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes) { + public static ParseSingleExample create(Scope scope, Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSingleExample", scope.makeOpName("ParseSingleExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInputList(Operands.asOutputs(denseDefaults)); @@ -92,7 +92,7 @@ public static ParseSingleExample create(Scope scope, Operand serialized denseKeysArray[i] = denseKeys.get(i); } opBuilder.setAttr("dense_keys", denseKeysArray); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; + Class[] sparseTypesArray = new Class[sparseTypes.size()]; for (int i = 0; i < sparseTypesArray.length; ++i) { sparseTypesArray[i] = sparseTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java index bb1e1eaa576..89105dc77ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -32,6 +31,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. @@ -122,7 +122,7 @@ private Options() { * @return a new instance of ParseSingleSequenceExample */ @Endpoint(describeByClass = true) - public static ParseSingleSequenceExample create(Scope scope, Operand serialized, Operand featureListDenseMissingAssumedEmpty, Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, Operand debugName, List> contextSparseTypes, List> featureListDenseTypes, List> featureListSparseTypes, Options... options) { + public static ParseSingleSequenceExample create(Scope scope, Operand serialized, Operand featureListDenseMissingAssumedEmpty, Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, Operand debugName, List> contextSparseTypes, List> featureListDenseTypes, List> featureListSparseTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSingleSequenceExample", scope.makeOpName("ParseSingleSequenceExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(featureListDenseMissingAssumedEmpty.asOutput()); @@ -133,17 +133,17 @@ public static ParseSingleSequenceExample create(Scope scope, Operand se opBuilder.addInputList(Operands.asOutputs(contextDenseDefaults)); opBuilder.addInput(debugName.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] contextSparseTypesArray = new DataType[contextSparseTypes.size()]; + Class[] contextSparseTypesArray = new Class[contextSparseTypes.size()]; for (int i = 0; i < contextSparseTypesArray.length; ++i) { contextSparseTypesArray[i] = contextSparseTypes.get(i); } opBuilder.setAttr("context_sparse_types", contextSparseTypesArray); - DataType[] featureListDenseTypesArray = new DataType[featureListDenseTypes.size()]; + Class[] featureListDenseTypesArray = new Class[featureListDenseTypes.size()]; for (int i = 0; i < featureListDenseTypesArray.length; ++i) { featureListDenseTypesArray[i] = featureListDenseTypes.get(i); } opBuilder.setAttr("feature_list_dense_types", featureListDenseTypesArray); - DataType[] featureListSparseTypesArray = new DataType[featureListSparseTypes.size()]; + Class[] featureListSparseTypesArray = new Class[featureListSparseTypes.size()]; for (int i = 0; i < featureListSparseTypesArray.length; ++i) { featureListSparseTypesArray[i] = featureListSparseTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java index bebdc1b4419..a45fa3a542f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,7 +46,7 @@ public final class ParseTensor extends RawOp implements Operand * @return a new instance of ParseTensor */ @Endpoint(describeByClass = true) - public static ParseTensor create(Scope scope, Operand serialized, DataType outType) { + public static ParseTensor create(Scope scope, Operand serialized, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("ParseTensor", scope.makeOpName("ParseTensor")); opBuilder.addInput(serialized.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java index 569d4ae7eb3..b8aba81d587 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java @@ -18,7 +18,6 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -95,10 +94,10 @@ private Options() { * @return a new instance of PriorityQueue */ @Endpoint(describeByClass = true) - public static PriorityQueue create(Scope scope, List> componentTypes, List shapes, Options... options) { + public static PriorityQueue create(Scope scope, List> componentTypes, List shapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PriorityQueueV2", scope.makeOpName("PriorityQueue")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java index b74fb3ad3e8..bf0cce74ece 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -75,11 +74,11 @@ private Options() { * @return a new instance of QueueDequeue */ @Endpoint(describeByClass = true) - public static QueueDequeue create(Scope scope, Operand handle, List> componentTypes, Options... options) { + public static QueueDequeue create(Scope scope, Operand handle, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueV2", scope.makeOpName("QueueDequeue")); opBuilder.addInput(handle.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java index 2906fe6543c..c24cb4251f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -84,12 +83,12 @@ private Options() { * @return a new instance of QueueDequeueMany */ @Endpoint(describeByClass = true) - public static QueueDequeueMany create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { + public static QueueDequeueMany create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueManyV2", scope.makeOpName("QueueDequeueMany")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(n.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java index 51af788104c..d2302e77dea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -88,12 +87,12 @@ private Options() { * @return a new instance of QueueDequeueUpTo */ @Endpoint(describeByClass = true) - public static QueueDequeueUpTo create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { + public static QueueDequeueUpTo create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueUpToV2", scope.makeOpName("QueueDequeueUpTo")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(n.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java index cdde1debc95..eb2d6a4ad52 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java @@ -18,7 +18,6 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -127,10 +126,10 @@ private Options() { * @return a new instance of RandomShuffleQueue */ @Endpoint(describeByClass = true) - public static RandomShuffleQueue create(Scope scope, List> componentTypes, Options... options) { + public static RandomShuffleQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomShuffleQueueV2", scope.makeOpName("RandomShuffleQueue")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; + Class[] componentTypesArray = new Class[componentTypes.size()]; for (int i = 0; i < componentTypesArray.length; ++i) { componentTypesArray[i] = componentTypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java index 0afbdcd6edd..35b1bbfae09 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -58,7 +57,7 @@ public final class SerializeManySparse extends RawOp implements * @return a new instance of SerializeManySparse */ @Endpoint(describeByClass = true) - public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, DataType outType) { + public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("SerializeManySparse", scope.makeOpName("SerializeManySparse")); opBuilder.addInput(sparseIndices.asOutput()); opBuilder.addInput(sparseValues.asOutput()); @@ -79,7 +78,7 @@ public static SerializeManySparse create(S */ @Endpoint(describeByClass = true) public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape) { - return create(scope, sparseIndices, sparseValues, sparseShape, TString.DTYPE); + return create(scope, sparseIndices, sparseValues, sparseShape, TString.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java index 2d8fc7838e4..fc4e5c09f21 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java @@ -17,7 +17,6 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -50,7 +49,7 @@ public final class SerializeSparse extends RawOp implements Ope * @return a new instance of SerializeSparse */ @Endpoint(describeByClass = true) - public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, DataType outType) { + public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("SerializeSparse", scope.makeOpName("SerializeSparse")); opBuilder.addInput(sparseIndices.asOutput()); opBuilder.addInput(sparseValues.asOutput()); @@ -71,7 +70,7 @@ public static SerializeSparse create(Scope */ @Endpoint(describeByClass = true) public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape) { - return create(scope, sparseIndices, sparseValues, sparseShape, TString.DTYPE); + return create(scope, sparseIndices, sparseValues, sparseShape, TString.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java index 733139f4359..40a66933f4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java index d56f45e8abf..caa625ffe79 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java index 36bbee414e9..69017d8dcfc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java index 677b8aa0e7b..2748e30a02b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java index ae4a4366b3b..ca30c69f971 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java index cb1baa772fd..c747840b107 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java index e589f5c49fe..e85fea495fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code e()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java index b4171907853..e5ac38214dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the reverse mode backpropagated gradient of the Cholesky algorithm. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java index 0afc8a2bf60..d223c04922e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the pairwise cross product. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java index 68c33487129..e02e54defe6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -78,7 +77,7 @@ private Options() { * @return a new instance of Eig */ @Endpoint(describeByClass = true) - public static Eig create(Scope scope, Operand input, DataType Tout, Options... options) { + public static Eig create(Scope scope, Operand input, Class Tout, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Eig", scope.makeOpName("Eig")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java index ce37bbcbea2..fcbdbd6d80d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -67,7 +66,7 @@ public final class Lu extends RawOp { * @return a new instance of Lu */ @Endpoint(describeByClass = true) - public static Lu create(Scope scope, Operand input, DataType outputIdxType) { + public static Lu create(Scope scope, Operand input, Class outputIdxType) { OperationBuilder opBuilder = scope.env().opBuilder("Lu", scope.makeOpName("Lu")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -85,7 +84,7 @@ public static Lu create(Scope scope, */ @Endpoint(describeByClass = true) public static Lu create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java index 445d27e1591..0fc70a09b7d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -87,7 +86,7 @@ private Options() { * @return a new instance of QuantizedMatMul */ @Endpoint(describeByClass = true) - public static QuantizedMatMul create(Scope scope, Operand a, Operand b, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, DataType Tactivation, Options... options) { + public static QuantizedMatMul create(Scope scope, Operand a, Operand b, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Class Tactivation, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMul", scope.makeOpName("QuantizedMatMul")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBias.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBias.java index 52ba1e9576b..4f083e85932 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBias.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBias.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -97,7 +96,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBias */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBias create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBias create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBias", scope.makeOpName("QuantizedMatMulWithBias")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndRelu.java index 4d7ba46f49e..c5769d336d9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndRelu.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -98,7 +97,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndRelu create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndRelu create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndRelu", scope.makeOpName("QuantizedMatMulWithBiasAndRelu")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java index e5639594805..a580c953e74 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -101,7 +100,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndReluAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndReluAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedMatMulWithBiasAndReluAndRequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixComponents.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixComponents.java index 332382cec08..3d598e9f450 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixComponents.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixComponents.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class CSRSparseMatrixComponents extends RawOp { * @return a new instance of CSRSparseMatrixComponents */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixComponents create(Scope scope, Operand csrSparseMatrix, Operand index, DataType type) { + public static CSRSparseMatrixComponents create(Scope scope, Operand csrSparseMatrix, Operand index, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixComponents", scope.makeOpName("CSRSparseMatrixComponents")); opBuilder.addInput(csrSparseMatrix.asOutput()); opBuilder.addInput(index.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java index 0d8ceba40d3..6b5fe3c0e82 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -44,7 +43,7 @@ public final class CSRSparseMatrixToDense extends RawOp impleme * @return a new instance of CSRSparseMatrixToDense */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixToDense create(Scope scope, Operand sparseInput, DataType type) { + public static CSRSparseMatrixToDense create(Scope scope, Operand sparseInput, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixToDense", scope.makeOpName("CSRSparseMatrixToDense")); opBuilder.addInput(sparseInput.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java index 0b5a2d41df8..3e203cac761 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,7 +44,7 @@ public final class CSRSparseMatrixToSparseTensor extends RawOp * @return a new instance of CSRSparseMatrixToSparseTensor */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixToSparseTensor create(Scope scope, Operand sparseMatrix, DataType type) { + public static CSRSparseMatrixToSparseTensor create(Scope scope, Operand sparseMatrix, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixToSparseTensor", scope.makeOpName("CSRSparseMatrixToSparseTensor")); opBuilder.addInput(sparseMatrix.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java index 01cec0f84da..e8043662d36 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class SparseMatrixSoftmax extends RawOp implements Operand { * @return a new instance of SparseMatrixSoftmax */ @Endpoint(describeByClass = true) - public static SparseMatrixSoftmax create(Scope scope, Operand logits, DataType type) { + public static SparseMatrixSoftmax create(Scope scope, Operand logits, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSoftmax", scope.makeOpName("SparseMatrixSoftmax")); opBuilder.addInput(logits.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java index 8ac96d72177..cea37dd4b25 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -44,7 +43,7 @@ public final class SparseMatrixSoftmaxGrad extends RawOp implements Operand SparseMatrixSoftmaxGrad create(Scope scope, Operand softmax, Operand gradSoftmax, DataType type) { + public static SparseMatrixSoftmaxGrad create(Scope scope, Operand softmax, Operand gradSoftmax, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSoftmaxGrad", scope.makeOpName("SparseMatrixSoftmaxGrad")); opBuilder.addInput(softmax.asOutput()); opBuilder.addInput(gradSoftmax.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java index 82f97e209d9..e70ba2b702b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -115,7 +114,7 @@ public final class SparseMatrixSparseCholesky extends RawOp implements Operand SparseMatrixSparseCholesky create(Scope scope, Operand input, Operand permutation, DataType type) { + public static SparseMatrixSparseCholesky create(Scope scope, Operand input, Operand permutation, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSparseCholesky", scope.makeOpName("SparseMatrixSparseCholesky")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(permutation.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java index 373ab3a01e2..26ae566f8be 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -163,7 +162,7 @@ private Options() { * @return a new instance of SparseMatrixSparseMatMul */ @Endpoint(describeByClass = true) - public static SparseMatrixSparseMatMul create(Scope scope, Operand a, Operand b, DataType type, Options... options) { + public static SparseMatrixSparseMatMul create(Scope scope, Operand a, Operand b, Class type, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSparseMatMul", scope.makeOpName("SparseMatrixSparseMatMul")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java index 90e5ba4294f..05e54b0e0e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -65,7 +64,7 @@ private Options() { * @return a new instance of SparseMatrixTranspose */ @Endpoint(describeByClass = true) - public static SparseMatrixTranspose create(Scope scope, Operand input, DataType type, Options... options) { + public static SparseMatrixTranspose create(Scope scope, Operand input, Class type, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixTranspose", scope.makeOpName("SparseMatrixTranspose")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java index 3c5bef9daed..439c4ce69b7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java @@ -17,7 +17,6 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -43,7 +42,7 @@ public final class SparseMatrixZeros extends RawOp implements Operand { * @return a new instance of SparseMatrixZeros */ @Endpoint(describeByClass = true) - public static SparseMatrixZeros create(Scope scope, Operand denseShape, DataType type) { + public static SparseMatrixZeros create(Scope scope, Operand denseShape, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixZeros", scope.makeOpName("SparseMatrixZeros")); opBuilder.addInput(denseShape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java index 7aa760a535b..7fdc1717833 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the absolute value of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java index ce996785568..097a29c35ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -63,7 +62,7 @@ public final class Angle extends RawOp implements Operand * @return a new instance of Angle */ @Endpoint(describeByClass = true) - public static Angle create(Scope scope, Operand input, DataType Tout) { + public static Angle create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Angle", scope.makeOpName("Angle")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -80,7 +79,7 @@ public static Angle create(Scope scope, */ @Endpoint(describeByClass = true) public static Angle create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java index b45c1268cb2..42eb6778d24 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -63,7 +62,7 @@ public final class ArgMax extends RawOp implements Operand * @return a new instance of ArgMax */ @Endpoint(describeByClass = true) - public static ArgMax create(Scope scope, Operand input, Operand dimension, DataType outputType) { + public static ArgMax create(Scope scope, Operand input, Operand dimension, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ArgMax", scope.makeOpName("ArgMax")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(dimension.asOutput()); @@ -84,7 +83,7 @@ public static ArgMax */ @Endpoint(describeByClass = true) public static ArgMax create(Scope scope, Operand input, Operand dimension) { - return create(scope, input, dimension, TInt64.DTYPE); + return create(scope, input, dimension, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java index 5a49adecd22..bd6b69d52e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -63,7 +62,7 @@ public final class ArgMin extends RawOp implements Operand * @return a new instance of ArgMin */ @Endpoint(describeByClass = true) - public static ArgMin create(Scope scope, Operand input, Operand dimension, DataType outputType) { + public static ArgMin create(Scope scope, Operand input, Operand dimension, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ArgMin", scope.makeOpName("ArgMin")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(dimension.asOutput()); @@ -84,7 +83,7 @@ public static ArgMin */ @Endpoint(describeByClass = true) public static ArgMin create(Scope scope, Operand input, Operand dimension) { - return create(scope, input, dimension, TInt64.DTYPE); + return create(scope, input, dimension, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java index d49bc324cb3..a70043da3a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes arctangent of `y/x` element-wise, respecting signs of the arguments. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java index 45dcd2b8e4c..ecf7aab945c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java index 816933ceb12..d274265faf0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java index 148758aa5a4..b836404991e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java index 3529c4a0bed..f8d40882585 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java index 53ecd2ed396..e28da25745f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the regularized incomplete beta integral \\(I_x(a, b)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java index 4217584c4fb..565a006c3b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java index c97e5b5ae5d..bda7247d508 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise smallest integer not less than x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java index 317744e519a..b899c47845a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -52,7 +51,7 @@ public final class ComplexAbs extends RawOp implements Operan * @return a new instance of ComplexAbs */ @Endpoint(describeByClass = true) - public static ComplexAbs create(Scope scope, Operand x, DataType Tout) { + public static ComplexAbs create(Scope scope, Operand x, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("ComplexAbs", scope.makeOpName("ComplexAbs")); opBuilder.addInput(x.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -69,7 +68,7 @@ public static ComplexAbs create(Scope sc */ @Endpoint(describeByClass = true) public static ComplexAbs create(Scope scope, Operand x) { - return create(scope, x, TFloat32.DTYPE); + return create(scope, x, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java index 3c8e6f89870..4f199369473 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the cumulative product of the tensor `x` along `axis`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java index e38d559f4ae..edbc39b49cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java index 2fd494f637c..571ade7795b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes Psi, the derivative of Lgamma (the log of the absolute value of diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java index 9f86126bff4..ff1c2d343e8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the Gauss error function of `x` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java index b94fecf9ede..31826d91508 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the complementary error function of `x` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java index ac8d7a8ffba..acd6f3a1d2b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise largest integer not greater than x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java index 715628b57a3..ae4e933ccea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. When `x < 0` xor `y < 0` is diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java index e8ef3811fd1..a0f100e6b03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x > y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java index 11a83743031..fa67314267f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x >= y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java index a4687ef3ab1..6d757ea96c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the lower regularized incomplete Gamma function `P(a, x)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java index e57b8d886e1..9e757f8e01d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of `igamma(a, x)` wrt `a`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java index 1e2cb0cf0f7..dec4671fe76 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the upper regularized incomplete Gamma function `Q(a, x)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java index 48a8a7107a3..9c4ce01655e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -59,7 +58,7 @@ public final class Imag extends RawOp implements Operand { * @return a new instance of Imag */ @Endpoint(describeByClass = true) - public static Imag create(Scope scope, Operand input, DataType Tout) { + public static Imag create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Imag", scope.makeOpName("Imag")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -76,7 +75,7 @@ public static Imag create(Scope scope, O */ @Endpoint(describeByClass = true) public static Imag create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java index de9e9c46932..859d648b2c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the inverse permutation of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java index ffe9bb99f10..b6536046d53 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are finite. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java index 3d36e9fbfe9..cde3b798d16 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are Inf. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java index b58205a005c..0ecb4836cee 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are NaN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java index 3d796296f01..cf696b5a759 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x < y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java index 618cd7a9866..4b4b85a8198 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x <= y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java index 9bc18b1b3b8..8b3070659e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the log of the absolute value of `Gamma(x)` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java index 9f72150c5d3..b6360e35d91 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the max of x and y (i.e. x > y ? x : y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java index e11b6e484fc..7a3e003bfad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the min of x and y (i.e. x < y ? x : y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java index 7ba98b81b39..0502e287fc5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. This emulates C semantics in that diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java index 55dcf0c434d..eec0a0d1c9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java index 8fa53306eed..7f494a874d5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the next representable value of `x1` in the direction of `x2`, element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java index d0021efc2ac..173ce8cd757 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the polygamma function \\(\psi^{(n)}(x)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java index bc721897426..68c9ebf055b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java index ccd6d2e5f98..2c956969e8d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,7 +50,7 @@ public final class QuantizedAdd extends RawOp { * @return a new instance of QuantizedAdd */ @Endpoint(describeByClass = true) - public static QuantizedAdd create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, DataType Toutput) { + public static QuantizedAdd create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, Class Toutput) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedAdd", scope.makeOpName("QuantizedAdd")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java index e16dc423a63..bd666426bd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,7 +50,7 @@ public final class QuantizedMul extends RawOp { * @return a new instance of QuantizedMul */ @Endpoint(describeByClass = true) - public static QuantizedMul create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, DataType Toutput) { + public static QuantizedMul create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, Class Toutput) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMul", scope.makeOpName("QuantizedMul")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java index e113597afd8..bd00e9a4ae6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -59,7 +58,7 @@ public final class Real extends RawOp implements Operand { * @return a new instance of Real */ @Endpoint(describeByClass = true) - public static Real create(Scope scope, Operand input, DataType Tout) { + public static Real create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Real", scope.makeOpName("Real")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -76,7 +75,7 @@ public static Real create(Scope scope, O */ @Endpoint(describeByClass = true) public static Real create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java index 53a6201151e..515da442497 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class RequantizePerChannel extends RawOp { * @return a new instance of RequantizePerChannel */ @Endpoint(describeByClass = true) - public static RequantizePerChannel create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, DataType outType) { + public static RequantizePerChannel create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("RequantizePerChannel", scope.makeOpName("RequantizePerChannel")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java index b7b152248a8..849508eba02 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise integer closest to x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java index fdd40054420..e044212668b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the maximum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java index b7ab590e976..b628c6f5eb1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the minimum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java index 131cb7ed792..e7c9c43d20d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java @@ -17,7 +17,6 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generates points from the Sobol sequence. @@ -54,7 +52,7 @@ public final class SobolSample extends RawOp implements Opera * @return a new instance of SobolSample */ @Endpoint(describeByClass = true) - public static SobolSample create(Scope scope, Operand dim, Operand numResults, Operand skip, DataType dtype) { + public static SobolSample create(Scope scope, Operand dim, Operand numResults, Operand skip, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("SobolSample", scope.makeOpName("SobolSample")); opBuilder.addInput(dim.asOutput()); opBuilder.addInput(numResults.asOutput()); @@ -77,7 +75,7 @@ public static SobolSample create(Scope scope, Operand create(Scope scope, Operand dim, Operand numResults, Operand skip) { - return create(scope, dim, numResults, skip, TFloat32.DTYPE); + return create(scope, dim, numResults, skip, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java index 79adcf30d76..5065f948cc1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softplus: `log(exp(features) + 1)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java index 90174757bc3..100243c60e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softplus gradients for a softplus operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java index 48c92574eb4..0fb4397c881 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. This emulates C semantics in that diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java index 582113e3320..733e3105846 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the maximum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java index 397ac5706a0..408995715a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the minimum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java index 0f3cee188fc..576af0a1e6e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the Hurwitz zeta function \\(\zeta(x, q)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java index de2a4482f65..e74d429e129 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java index 8d2184a49cb..c49c2d69350 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java index d8f9621a36c..fb7f29046de 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java index eaae243f83f..945b617b1dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java index c57ae64e233..d1a014d4959 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java index 1858d25fe3d..0cb60fa3d88 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java index e4a5cc23efd..b01a57da175 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java index 9228d1b6145..619147cbb6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java index 0461416b808..51eb6e377f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java index 74388434149..529514c9bf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java index b36c55fdeb6..099cd544ada 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java index bb9a9f47e78..5d5ce736e42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java index 36681c87678..5914fe4da76 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java index ed613a28b1a..faf21f53510 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java index 527d1a49713..a27f1aa70c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java index 87467c8a982..177227c0a52 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs 3D average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java index 3e34c87f9b6..c23a0c5360c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of average pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java index 7ec252eaac0..ec149807546 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the average pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java index 79de4f2f88c..c40f2eeea4b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell forward propagation for all the time steps. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java index 4e49f23cb6e..f7a17b1bf76 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell backward propagation for the entire time sequence. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java index 79f2022d807..70a72bc78da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 2-D convolution given 4-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java index 9c145dc3d60..709d59dee28 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java index 492c4f2aebb..b725823e8db 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java index 94ae4ffd2b5..0720ef0dce5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 3-D convolution given 5-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java index 0d73e491717..1de81e9f781 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of 3-D convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java index 8b153890811..1f71f6b64f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of 3-D convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java index 96f179641e7..fadc3985f38 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java @@ -30,7 +30,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs beam search decoding on the logits given in input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java index b8b35b8ceaa..2c985fb4e03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs greedy decoding on the logits given in inputs. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java index ceb60a4baf7..44acdfbad19 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Calculates the CTC Loss (log probability) for each batch entry. Also calculates diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java index 7c179d7e578..44495c8b7a7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * A RNN backed by cuDNN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java index 3719c186716..9e51b12fee7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Backprop step of CudnnRNNV3. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java index 155cfd2c0a4..38df41c251a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java index ea575f2d7a1..a6fcda3786e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java @@ -29,7 +29,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Retrieves CudnnRNN params in canonical form. It supports the projection in LSTM. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java index 94421d0aa6e..dd3f4ba6392 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes size of weights that can be used by a Cudnn RNN model. @@ -146,7 +144,7 @@ private Options() { * @return a new instance of CudnnRnnParamsSize */ @Endpoint(describeByClass = true) - public static CudnnRnnParamsSize create(Scope scope, Operand numLayers, Operand numUnits, Operand inputSize, DataType T, DataType S, Options... options) { + public static CudnnRnnParamsSize create(Scope scope, Operand numLayers, Operand numUnits, Operand inputSize, Class T, Class S, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CudnnRNNParamsSize", scope.makeOpName("CudnnRnnParamsSize")); opBuilder.addInput(numLayers.asOutput()); opBuilder.addInput(numUnits.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java index 6b1ff40761b..3eb8a46d310 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the dimension index in the destination data format given the one in diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java index ba218c1923f..0ba1ddeb4ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the permuted vector/tensor in the destination data format given the diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java index 6549d8b9dfc..2aff519c733 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java index fadfdacc823..ac51c06e737 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of depthwise convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java index d2e3e733a01..26db9841d7d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of depthwise convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java index c7135b20361..e6bd8cc5b6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java index 9254ff8c285..04705f53024 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of morphological 2-D dilation with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java index 525e06182c5..1636545164c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of morphological 2-D dilation with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java index daeef97895b..faa70f04b98 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java index 664475879a0..9ead32623ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for the exponential linear (Elu) operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java index 101ea21ec8c..93a477ef25e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs fractional average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java index 03ae0136311..7a4258d18c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradient of the FractionalAvgPool function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java index a621e037740..9ca640c804b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs fractional max pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java index 1b2bcc62dbf..952c8ce286a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradient of the FractionalMaxPool function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java index 04e8d71e1e5..0d2b3a6dfe9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Batch normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java index 4d2ecc74a4f..8e9f616adcb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Gradient for batch normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java index 2d3de921f0b..89713030f7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs a padding as a preprocess during a convolution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java index 02f82242c7a..b89af628392 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs a resize and padding as a preprocess during a convolution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java index 446f43cfb0c..fad66727252 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the GRU cell forward propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java index 4a6b72b8e4c..69c4295d5c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the GRU cell back-propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java index d38388a3fd5..ea2a2652b83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Says whether the targets are in the top `K` predictions. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java index b38d9e99d96..eb38a9baa65 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * L2 Loss. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java index ab5d1b8aee2..160d2bc89be 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell forward propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java index 3635ec433d6..3a36077d6ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell backward propagation for 1 timestep. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java index 8ca3f540cad..96f44eacb5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear: `max(features, features * alpha)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java index a560c9cfb58..e7d72196b4a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Local Response Normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java index 08f10c12f26..d347bec7237 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Gradients for Local Response Normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java index bd1bff467a1..e9e123e516d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes log softmax activations. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java index 61bf43372cc..320f38a084e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs 3D max pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java index 2a4fe5e61f9..7f2938e8827 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of 3D max pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java index 4243e1e5143..ef4c546c41e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java index 6f5cbd2ce64..e391585b657 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java index 4f9255e9286..2653f1d51c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java index 6258024daa9..a15222e7c4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java index 329c88b4d33..c51275f68f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java index 63f92de1fea..841e598ee3a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs max pooling on the input and outputs both max values and indices. @@ -83,7 +81,7 @@ private Options() { * @return a new instance of MaxPoolWithArgmax */ @Endpoint(describeByClass = true) - public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, DataType Targmax, String padding, Options... options) { + public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, Class Targmax, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MaxPoolWithArgmax", scope.makeOpName("MaxPoolWithArgmax")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -123,7 +121,7 @@ public static MaxPoolWithArgmax cre */ @Endpoint(describeByClass = true) public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, String padding, Options... options) { - return create(scope, input, ksize, strides, TInt64.DTYPE, padding, options); + return create(scope, input, ksize, strides, TInt64.class, padding, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java index 340cd4b98f7..35f827f5d01 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Finds values of the `n`-th order statistic for the last dimension. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java index 41cc16017fc..57b2af3a66d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -73,7 +72,7 @@ public final class QuantizedBatchNormWithGlobalNormalization ex * @return a new instance of QuantizedBatchNormWithGlobalNormalization */ @Endpoint(describeByClass = true) - public static QuantizedBatchNormWithGlobalNormalization create(Scope scope, Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, Operand gamma, Operand gammaMin, Operand gammaMax, DataType outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { + public static QuantizedBatchNormWithGlobalNormalization create(Scope scope, Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, Operand gamma, Operand gammaMin, Operand gammaMax, Class outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedBatchNormWithGlobalNormalization", scope.makeOpName("QuantizedBatchNormWithGlobalNormalization")); opBuilder.addInput(t.asOutput()); opBuilder.addInput(tMin.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java index 62a8002ad49..b39db777a05 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,7 +52,7 @@ public final class QuantizedBiasAdd extends RawOp { * @return a new instance of QuantizedBiasAdd */ @Endpoint(describeByClass = true) - public static QuantizedBiasAdd create(Scope scope, Operand input, Operand bias, Operand minInput, Operand maxInput, Operand minBias, Operand maxBias, DataType outType) { + public static QuantizedBiasAdd create(Scope scope, Operand input, Operand bias, Operand minInput, Operand maxInput, Operand minBias, Operand maxBias, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedBiasAdd", scope.makeOpName("QuantizedBiasAdd")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(bias.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java index 961a4245155..362a64d595e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -80,7 +79,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndRelu create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndRelu create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndRelu", scope.makeOpName("QuantizedConv2DAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java index 8b231edc09b..f55cae7b48f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -82,7 +81,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndReluAndRequantize", scope.makeOpName("QuantizedConv2DAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java index 8c9f1247d97..30f303ae8f9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -82,7 +81,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndRequantize", scope.makeOpName("QuantizedConv2DAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java index efc0132e2a8..23019817b6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -73,7 +72,7 @@ private Options() { * @return a new instance of QuantizedConv2DPerChannel */ @Endpoint(describeByClass = true) - public static QuantizedConv2DPerChannel create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DPerChannel create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DPerChannel", scope.makeOpName("QuantizedConv2DPerChannel")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java index bc78b5b4ec1..a0e9fab70ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -81,7 +80,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBias */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBias", scope.makeOpName("QuantizedConv2DWithBias")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java index 719150ee139..b0288a5adf0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -81,7 +80,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndRelu", scope.makeOpName("QuantizedConv2DWithBiasAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java index a61cab41d5e..c06f4ebe73c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -83,7 +82,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRequantize.java index af89fd48962..81d1f4497a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -83,7 +82,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.java index cfa7770cf25..3c60a3f3a1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSignedSumAndReluAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -86,7 +85,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasSignedSumAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasSignedSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSignedSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndRelu.java index 1096fa12578..0788d608ac1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndRelu.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -82,7 +81,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasSumAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasSumAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand summand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSumAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand summand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSumAndRelu", scope.makeOpName("QuantizedConv2DWithBiasSumAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java index 55adb0e016d..91c903bad91 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -86,7 +85,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasSumAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSumAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasSumAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2d.java index 39913922bdb..ffd66fc3f23 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2d.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -84,7 +83,7 @@ private Options() { * @return a new instance of QuantizedConv2d */ @Endpoint(describeByClass = true) - public static QuantizedConv2d create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2d create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2D", scope.makeOpName("QuantizedConv2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java index f64d9a7efbb..1e0a8c38e0f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -73,7 +72,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2D */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2D create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2D create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2D", scope.makeOpName("QuantizedDepthwiseConv2D")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java index be5e2bb8657..b0d29c64e39 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -74,7 +73,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBias */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBias", scope.makeOpName("QuantizedDepthwiseConv2DWithBias")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java index 8abd12b865f..e88bddc4f83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -83,7 +82,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBiasAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBiasAndRelu", scope.makeOpName("QuantizedDepthwiseConv2DWithBiasAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java index 78c8048266e..3430d0a6e2f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java @@ -18,7 +18,6 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -85,7 +84,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu.java index 308e14e9512..8881d2eb824 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,7 +47,7 @@ public final class QuantizedRelu extends RawOp { * @return a new instance of QuantizedRelu */ @Endpoint(describeByClass = true) - public static QuantizedRelu create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedRelu create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedRelu", scope.makeOpName("QuantizedRelu")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(minFeatures.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java index 46f04fd5722..2153c3561df 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -48,7 +47,7 @@ public final class QuantizedRelu6 extends RawOp { * @return a new instance of QuantizedRelu6 */ @Endpoint(describeByClass = true) - public static QuantizedRelu6 create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedRelu6 create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedRelu6", scope.makeOpName("QuantizedRelu6")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(minFeatures.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java index e47a6a5c043..9db60f2174c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java @@ -17,7 +17,6 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class QuantizedReluX extends RawOp { * @return a new instance of QuantizedReluX */ @Endpoint(describeByClass = true) - public static QuantizedReluX create(Scope scope, Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedReluX create(Scope scope, Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedReluX", scope.makeOpName("QuantizedReluX")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(maxValue.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java index e4e674b7e35..f227e995b2e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear 6: `min(max(features, 0), 6)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java index cf4edc9debf..043f3b5a5d7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear 6 gradients for a Relu6 operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java index d1c7cf7d44c..b41caedda64 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear gradients for a Relu operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java index 2ce8484b299..355222c25fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java index 514ad00b38d..6bf70f4038d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for the scaled exponential linear (Selu) operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java index d9971b6667d..093195aea28 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax activations. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java index e4276c679a3..0b2e3a639e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softsign: `features / (abs(features) + 1)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java index 220763e57a2..31367f06df5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softsign gradients for a softsign operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java index acd4ba679f7..a8bfdbea715 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Finds values and indices of the `k` largest elements for the last dimension. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java index 4c23683d9ef..1a5011c713a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax cross entropy cost and gradients to backpropagate. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java index e7cb45231de..c8dcac1a84a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax cross entropy cost and gradients to backpropagate. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java index 760dd9fe913..6c9484cab55 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -137,7 +136,7 @@ private Options() { * @return a new instance of Dequantize */ @Endpoint(describeByClass = true) - public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, DataType dtype, Options... options) { + public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Dequantize", scope.makeOpName("Dequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(minRange.asOutput()); @@ -172,7 +171,7 @@ public static Dequantize create(Scope sc */ @Endpoint(describeByClass = true) public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Options... options) { - return create(scope, input, minRange, maxRange, TFloat32.DTYPE, options); + return create(scope, input, minRange, maxRange, TFloat32.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java index 37f58cfa512..61370659157 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -220,7 +219,7 @@ private Options() { * @return a new instance of Quantize */ @Endpoint(describeByClass = true) - public static Quantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, DataType T, Options... options) { + public static Quantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Class T, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizeV2", scope.makeOpName("Quantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(minRange.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java index fd75b330e41..c14d9f451d7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Quantizes then dequantizes a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java index 362375b40e2..04f13bd833d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -71,7 +70,7 @@ public final class QuantizeDownAndShrinkRange extends RawOp { * @return a new instance of QuantizeDownAndShrinkRange */ @Endpoint(describeByClass = true) - public static QuantizeDownAndShrinkRange create(Scope scope, Operand input, Operand inputMin, Operand inputMax, DataType outType) { + public static QuantizeDownAndShrinkRange create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizeDownAndShrinkRange", scope.makeOpName("QuantizeDownAndShrinkRange")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java index baa20635c51..114d40bd058 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -90,7 +89,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndDequantize */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndDequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndDequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndDequantize", scope.makeOpName("QuantizedMatMulWithBiasAndDequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndRequantize.java index 950221b1b94..d553cd17e60 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndRequantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -89,7 +88,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndRequantize", scope.makeOpName("QuantizedMatMulWithBiasAndRequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Requantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Requantize.java index 5df8ca0b622..1ee6077abe1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Requantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Requantize.java @@ -17,7 +17,6 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -58,7 +57,7 @@ public final class Requantize extends RawOp { * @return a new instance of Requantize */ @Endpoint(describeByClass = true) - public static Requantize create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, DataType outType) { + public static Requantize create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Requantize", scope.makeOpName("Requantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java index 1e0224aa9ef..b5e54a6a822 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java index 4829e49488b..5148b8f1f00 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a ragged tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java index 9ea32878257..4a21e8cdfe4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java @@ -17,7 +17,6 @@ package org.tensorflow.op.ragged; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -62,7 +61,7 @@ public final class RaggedCross extends RawOp * @return a new instance of RaggedCross */ @Endpoint(describeByClass = true) - public static RaggedCross create(Scope scope, Iterable> raggedValues, Iterable> raggedRowSplits, Iterable> sparseIndices, Iterable> sparseValues, Iterable> sparseShape, Iterable> denseInputs, String inputOrder, Boolean hashedOutput, Long numBuckets, Long hashKey, DataType outValuesType, DataType outRowSplitsType) { + public static RaggedCross create(Scope scope, Iterable> raggedValues, Iterable> raggedRowSplits, Iterable> sparseIndices, Iterable> sparseValues, Iterable> sparseShape, Iterable> denseInputs, String inputOrder, Boolean hashedOutput, Long numBuckets, Long hashKey, Class outValuesType, Class outRowSplitsType) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedCross", scope.makeOpName("RaggedCross")); opBuilder.addInputList(Operands.asOutputs(raggedValues)); opBuilder.addInputList(Operands.asOutputs(raggedRowSplits)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java index 9d8f5594fd9..ea23e612ab2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java @@ -17,7 +17,6 @@ package org.tensorflow.op.ragged; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns a `RaggedTensor` containing the specified sequences of numbers. @@ -64,7 +62,7 @@ public final class RaggedRange extends Raw * @return a new instance of RaggedRange */ @Endpoint(describeByClass = true) - public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas, DataType Tsplits) { + public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas, Class Tsplits) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedRange", scope.makeOpName("RaggedRange")); opBuilder.addInput(starts.asOutput()); opBuilder.addInput(limits.asOutput()); @@ -85,7 +83,7 @@ public static RaggedRange create(Sc */ @Endpoint(describeByClass = true) public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas) { - return create(scope, starts, limits, deltas, TInt64.DTYPE); + return create(scope, starts, limits, deltas, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java index c9dcfe54bda..cd447134359 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -67,7 +66,7 @@ public final class RaggedTensorFromVariant e * @return a new instance of RaggedTensorFromVariant */ @Endpoint(describeByClass = true) - public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, DataType Tvalues, DataType Tsplits) { + public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, Class Tvalues, Class Tsplits) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedTensorFromVariant", scope.makeOpName("RaggedTensorFromVariant")); opBuilder.addInput(encodedRagged.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -91,8 +90,8 @@ public static RaggedTensorFromVariant * @return a new instance of RaggedTensorFromVariant */ @Endpoint(describeByClass = true) - public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, DataType Tvalues) { - return create(scope, encodedRagged, inputRaggedRank, outputRaggedRank, Tvalues, TInt64.DTYPE); + public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, Class Tvalues) { + return create(scope, encodedRagged, inputRaggedRank, outputRaggedRank, Tvalues, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java index 58cc57d6c52..fbfb90fba47 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draws samples from a multinomial distribution. @@ -80,7 +78,7 @@ private Options() { * @return a new instance of Multinomial */ @Endpoint(describeByClass = true) - public static Multinomial create(Scope scope, Operand logits, Operand numSamples, DataType outputDtype, Options... options) { + public static Multinomial create(Scope scope, Operand logits, Operand numSamples, Class outputDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Multinomial", scope.makeOpName("Multinomial")); opBuilder.addInput(logits.asOutput()); opBuilder.addInput(numSamples.asOutput()); @@ -111,7 +109,7 @@ public static Multinomial create(Scope */ @Endpoint(describeByClass = true) public static Multinomial create(Scope scope, Operand logits, Operand numSamples, Options... options) { - return create(scope, logits, numSamples, TInt64.DTYPE, options); + return create(scope, logits, numSamples, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java index 246974eaf6f..80ee33ef06b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -47,7 +46,7 @@ public final class NonDeterministicInts extends RawOp implement * @return a new instance of NonDeterministicInts */ @Endpoint(describeByClass = true) - public static NonDeterministicInts create(Scope scope, Operand shape, DataType dtype) { + public static NonDeterministicInts create(Scope scope, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("NonDeterministicInts", scope.makeOpName("NonDeterministicInts")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -64,7 +63,7 @@ public static NonDeterministicInts create( */ @Endpoint(describeByClass = true) public static NonDeterministicInts create(Scope scope, Operand shape) { - return create(scope, shape, TInt64.DTYPE); + return create(scope, shape, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java index 4be50b9cde0..c9933d29a17 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a normal distribution. The parameters may each be a diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java index 13963e09ecb..aee60a06932 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from the Gamma distribution(s) described by alpha. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java index ce3798cef3f..29c7c625e17 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the derivative of a Gamma random sample w.r.t. `alpha`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java index d4b516343ac..04da9a55cdd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from the Poisson distribution(s) described by rate. @@ -91,7 +89,7 @@ private Options() { * @return a new instance of RandomPoisson */ @Endpoint(describeByClass = true) - public static RandomPoisson create(Scope scope, Operand shape, Operand rate, DataType dtype, Options... options) { + public static RandomPoisson create(Scope scope, Operand shape, Operand rate, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomPoissonV2", scope.makeOpName("RandomPoisson")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(rate.asOutput()); @@ -123,7 +121,7 @@ public static RandomPo */ @Endpoint(describeByClass = true) public static RandomPoisson create(Scope scope, Operand shape, Operand rate, Options... options) { - return create(scope, shape, rate, TInt64.DTYPE, options); + return create(scope, shape, rate, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java index bdd971cc19d..c9472874e46 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a normal distribution. @@ -79,7 +77,7 @@ private Options() { * @return a new instance of RandomStandardNormal */ @Endpoint(describeByClass = true) - public static RandomStandardNormal create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static RandomStandardNormal create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomStandardNormal", scope.makeOpName("RandomStandardNormal")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java index 5e42c2d9691..43952690b7d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a uniform distribution. @@ -80,7 +78,7 @@ private Options() { * @return a new instance of RandomUniform */ @Endpoint(describeByClass = true) - public static RandomUniform create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static RandomUniform create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomUniform", scope.makeOpName("RandomUniform")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java index 5232135ac1c..1b59cf964e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random integers from a uniform distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java index b3c2dfce166..f0d68949af1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output @@ -49,7 +47,7 @@ public final class StatefulRandomBinomial extends RawOp imple * @return a new instance of StatefulRandomBinomial */ @Endpoint(describeByClass = true) - public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs, DataType dtype) { + public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulRandomBinomial", scope.makeOpName("StatefulRandomBinomial")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); @@ -74,7 +72,7 @@ public static Stateful */ @Endpoint(describeByClass = true) public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs) { - return create(scope, resource, algorithm, shape, counts, probs, TInt64.DTYPE); + return create(scope, resource, algorithm, shape, counts, probs, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java index 12aabfa2a73..a01c0f740b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,7 +50,7 @@ public final class StatefulStandardNormal extends RawOp impleme * @return a new instance of StatefulStandardNormal */ @Endpoint(describeByClass = true) - public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulStandardNormalV2", scope.makeOpName("StatefulStandardNormal")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); @@ -72,7 +71,7 @@ public static StatefulStandardNormal creat */ @Endpoint(describeByClass = true) public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java index 86904de711f..461389eaec5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -52,7 +51,7 @@ public final class StatefulTruncatedNormal extends RawOp implem * @return a new instance of StatefulTruncatedNormal */ @Endpoint(describeByClass = true) - public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulTruncatedNormal", scope.makeOpName("StatefulTruncatedNormal")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); @@ -73,7 +72,7 @@ public static StatefulTruncatedNormal crea */ @Endpoint(describeByClass = true) public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java index dd9a2c10af0..89fa8bdd32d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -51,7 +50,7 @@ public final class StatefulUniform extends RawOp implements Ope * @return a new instance of StatefulUniform */ @Endpoint(describeByClass = true) - public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulUniform", scope.makeOpName("StatefulUniform")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); @@ -72,7 +71,7 @@ public static StatefulUniform create(Scope */ @Endpoint(describeByClass = true) public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java index 0bf991bd72e..457dcc0b1a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,7 +48,7 @@ public final class StatefulUniformFullInt extends RawOp impleme * @return a new instance of StatefulUniformFullInt */ @Endpoint(describeByClass = true) - public static StatefulUniformFullInt create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulUniformFullInt create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulUniformFullInt", scope.makeOpName("StatefulUniformFullInt")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java index 1dbe0e5a7d0..3bc5ab365a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draws samples from a multinomial distribution. @@ -51,7 +49,7 @@ public final class StatelessMultinomial extends RawOp impleme * @return a new instance of StatelessMultinomial */ @Endpoint(describeByClass = true) - public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed, DataType outputDtype) { + public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed, Class outputDtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessMultinomial", scope.makeOpName("StatelessMultinomial")); opBuilder.addInput(logits.asOutput()); opBuilder.addInput(numSamples.asOutput()); @@ -73,7 +71,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed) { - return create(scope, logits, numSamples, seed, TInt64.DTYPE); + return create(scope, logits, numSamples, seed, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java index 179160463c7..f422a582a0f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java index 03543495413..2c2fe32cf12 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a binomial distribution. @@ -55,7 +53,7 @@ public final class StatelessRandomBinomial extends RawOp impl * @return a new instance of StatelessRandomBinomial */ @Endpoint(describeByClass = true) - public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs, DataType dtype) { + public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomBinomial", scope.makeOpName("StatelessRandomBinomial")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); @@ -80,7 +78,7 @@ public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs) { - return create(scope, shape, seed, counts, probs, TInt64.DTYPE); + return create(scope, shape, seed, counts, probs, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java index 2c4eef75ff7..7c8aabfa082 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a gamma distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java index 07c5298cca0..2b20afb21b1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom values from a normal distribution. @@ -52,7 +50,7 @@ public final class StatelessRandomNormal extends RawOp implem * @return a new instance of StatelessRandomNormal */ @Endpoint(describeByClass = true) - public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomNormal", scope.makeOpName("StatelessRandomNormal")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); @@ -71,7 +69,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java index e71c70e2c1f..cde1f6609b9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a Poisson distribution. @@ -52,7 +50,7 @@ public final class StatelessRandomPoisson extends RawOp imple * @return a new instance of StatelessRandomPoisson */ @Endpoint(describeByClass = true) - public static StatelessRandomPoisson create(Scope scope, Operand shape, Operand seed, Operand lam, DataType dtype) { + public static StatelessRandomPoisson create(Scope scope, Operand shape, Operand seed, Operand lam, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomPoisson", scope.makeOpName("StatelessRandomPoisson")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java index 9eb6edc67e5..c94282163f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random values from a uniform distribution. @@ -53,7 +51,7 @@ public final class StatelessRandomUniform extends RawOp imple * @return a new instance of StatelessRandomUniform */ @Endpoint(describeByClass = true) - public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomUniform", scope.makeOpName("StatelessRandomUniform")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); @@ -72,7 +70,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java index 291cce74d19..b774c165ec0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random integers from a uniform distribution. @@ -50,7 +48,7 @@ public final class StatelessRandomUniformFullInt extends RawO * @return a new instance of StatelessRandomUniformFullInt */ @Endpoint(describeByClass = true) - public static StatelessRandomUniformFullInt create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomUniformFullInt create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomUniformFullInt", scope.makeOpName("StatelessRandomUniformFullInt")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java index 4695718a186..e8bbba18ab5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random integers from a uniform distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java index 3d4761fc4df..cf260db14a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -28,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom values from a truncated normal distribution. @@ -54,7 +52,7 @@ public final class StatelessTruncatedNormal extends RawOp imp * @return a new instance of StatelessTruncatedNormal */ @Endpoint(describeByClass = true) - public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessTruncatedNormal", scope.makeOpName("StatelessTruncatedNormal")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); @@ -73,7 +71,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java index 7b88d720386..dc7aee67856 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java @@ -17,7 +17,6 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -27,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a truncated normal distribution. @@ -81,7 +79,7 @@ private Options() { * @return a new instance of TruncatedNormal */ @Endpoint(describeByClass = true) - public static TruncatedNormal create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static TruncatedNormal create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TruncatedNormal", scope.makeOpName("TruncatedNormal")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java index 80d3bb85291..b0441b47911 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -63,7 +62,7 @@ public final class Irfft extends RawOp implements Operand * @return a new instance of Irfft */ @Endpoint(describeByClass = true) - public static Irfft create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT", scope.makeOpName("Irfft")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); @@ -82,7 +81,7 @@ public static Irfft create(Scope scope, */ @Endpoint(describeByClass = true) public static Irfft create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java index 8acf23a4f23..726e1ea31e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -64,7 +63,7 @@ public final class Irfft2d extends RawOp implements Operand Irfft2d create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft2d create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT2D", scope.makeOpName("Irfft2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); @@ -83,7 +82,7 @@ public static Irfft2d create(Scope scope */ @Endpoint(describeByClass = true) public static Irfft2d create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java index c7b9efabfd2..0d76b5591e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -64,7 +63,7 @@ public final class Irfft3d extends RawOp implements Operand Irfft3d create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft3d create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT3D", scope.makeOpName("Irfft3d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); @@ -83,7 +82,7 @@ public static Irfft3d create(Scope scope */ @Endpoint(describeByClass = true) public static Irfft3d create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java index 9764bcbf0f2..5d03d58a7b7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -59,7 +58,7 @@ public final class Rfft extends RawOp implements Operand { * @return a new instance of Rfft */ @Endpoint(describeByClass = true) - public static Rfft create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT", scope.makeOpName("Rfft")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java index 91187dced7b..ddecd0f133f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -60,7 +59,7 @@ public final class Rfft2d extends RawOp implements Operand { * @return a new instance of Rfft2d */ @Endpoint(describeByClass = true) - public static Rfft2d create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft2d create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT2D", scope.makeOpName("Rfft2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java index 1eb113e9cf1..8fd625822e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java @@ -17,7 +17,6 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -60,7 +59,7 @@ public final class Rfft3d extends RawOp implements Operand { * @return a new instance of Rfft3d */ @Endpoint(describeByClass = true) - public static Rfft3d create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft3d create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT3D", scope.makeOpName("Rfft3d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java index ed390a7ba47..de6c4ef9eb0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a tf.tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java index a1a6bb2fa49..871910a0874 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java @@ -17,7 +17,6 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -89,7 +88,7 @@ public final class DeserializeSparse extends RawOp { * @return a new instance of DeserializeSparse */ @Endpoint(describeByClass = true) - public static DeserializeSparse create(Scope scope, Operand serializedSparse, DataType dtype) { + public static DeserializeSparse create(Scope scope, Operand serializedSparse, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("DeserializeSparse", scope.makeOpName("DeserializeSparse")); opBuilder.addInput(serializedSparse.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java index 87cbd112e57..ebd3260f019 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java @@ -17,7 +17,6 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -57,7 +56,7 @@ public final class SparseAccumulatorTakeGradient extends RawOp * @return a new instance of SparseAccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static SparseAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static SparseAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("SparseAccumulatorTakeGradient", scope.makeOpName("SparseAccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java index 344e27f1346..81d72b8a81e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java index b38fc0a6c46..863b2d0b8e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java @@ -17,7 +17,6 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -92,7 +91,7 @@ private Options() { * @return a new instance of SparseConditionalAccumulator */ @Endpoint(describeByClass = true) - public static SparseConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static SparseConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseConditionalAccumulator", scope.makeOpName("SparseConditionalAccumulator")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java index 5e5566db5ec..9f68897c075 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a sparse tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java index 295063d560a..fc66887070f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Multiply matrix "a" by matrix "b". diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java index 5288f7b5b88..85c01083363 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the max of elements across dimensions of a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java index 93ce28bb66a..2c2bbfb9348 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the max of elements across dimensions of a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java index 14dd3df8c2f..72e5784e390 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the mean along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java index b5ae14a0f03..7fc5d505ee6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for SparseSegmentMean. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java index 214d4983c81..9e2047047f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the mean along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java index c1889eae169..dccfcc0d1c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor divided by the sqrt of N. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java index 3366dbac0c8..a9b7fb3023b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for SparseSegmentSqrtN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java index 41000f65cdf..127ffa87198 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor divided by the sqrt of N. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java index 991653b2c02..321fbaed9ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java index f5414151cbe..830ca794891 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java index 0906213115f..81f0cd9bfe4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Applies softmax to a batched N-D `SparseTensor`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java index 3cbf1a6ef7d..976ac5b4633 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the element-wise max of two SparseTensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java index 7fd5242cb20..6e1bbb6ed27 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java @@ -17,7 +17,6 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -124,7 +123,7 @@ private Options() { * @return a new instance of TakeManySparseFromTensorsMap */ @Endpoint(describeByClass = true) - public static TakeManySparseFromTensorsMap create(Scope scope, Operand sparseHandles, DataType dtype, Options... options) { + public static TakeManySparseFromTensorsMap create(Scope scope, Operand sparseHandles, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TakeManySparseFromTensorsMap", scope.makeOpName("TakeManySparseFromTensorsMap")); opBuilder.addInput(sparseHandles.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java index 6d68c11fe1b..93d8b60ad87 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Creates ngrams from ragged string data. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java index f1eb5a05485..d09e93458ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return substrings from `Tensor` of strings. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java index e8e5e4039b9..47ddcfde480 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java @@ -17,7 +17,6 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts each string in the input Tensor to the specified numeric type. @@ -58,7 +56,7 @@ public final class ToNumber extends RawOp implements Operand< * @return a new instance of ToNumber */ @Endpoint(describeByClass = true) - public static ToNumber create(Scope scope, Operand stringTensor, DataType outType) { + public static ToNumber create(Scope scope, Operand stringTensor, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("StringToNumber", scope.makeOpName("ToNumber")); opBuilder.addInput(stringTensor.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -75,7 +73,7 @@ public static ToNumber create(Scope scope, Operand create(Scope scope, Operand stringTensor) { - return create(scope, stringTensor, TFloat32.DTYPE); + return create(scope, stringTensor, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java index a39b08b1f5a..1c73739a144 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java @@ -17,7 +17,6 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -30,7 +29,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decodes each string in `input` into a sequence of Unicode code points. @@ -116,7 +114,7 @@ private Options() { * @return a new instance of UnicodeDecode */ @Endpoint(describeByClass = true) - public static UnicodeDecode create(Scope scope, Operand input, String inputEncoding, DataType Tsplits, Options... options) { + public static UnicodeDecode create(Scope scope, Operand input, String inputEncoding, Class Tsplits, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("UnicodeDecode", scope.makeOpName("UnicodeDecode")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -151,7 +149,7 @@ public static UnicodeDecode create(Scope scope, Operand create(Scope scope, Operand input, String inputEncoding, Options... options) { - return create(scope, input, inputEncoding, TInt64.DTYPE, options); + return create(scope, input, inputEncoding, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java index ce6977b63ef..f0999df6285 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java @@ -17,7 +17,6 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -30,7 +29,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decodes each string in `input` into a sequence of Unicode code points. @@ -122,7 +120,7 @@ private Options() { * @return a new instance of UnicodeDecodeWithOffsets */ @Endpoint(describeByClass = true) - public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, DataType Tsplits, Options... options) { + public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, Class Tsplits, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("UnicodeDecodeWithOffsets", scope.makeOpName("UnicodeDecodeWithOffsets")); opBuilder.addInput(input.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); @@ -157,7 +155,7 @@ public static UnicodeDecodeWithOffsets create(Scope scope */ @Endpoint(describeByClass = true) public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, Options... options) { - return create(scope, input, inputEncoding, TInt64.DTYPE, options); + return create(scope, input, inputEncoding, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java index ab83ef9dcda..e499dee3157 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Encode a tensor of ints into unicode strings. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java index df1bbc68b68..2cd28375bd3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Joins the elements of `inputs` based on `segment_ids`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java index 669cbf5c2fb..b1bf74ba9d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a `Summary` protocol buffer with a histogram. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java index 5b0f9f8859f..b97fbfcf8da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java @@ -21,7 +21,6 @@ import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; -import org.tensorflow.Tensor; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -99,13 +98,13 @@ public Options maxImages(Long maxImages) { /** * @param badColor Color to use for pixels with non-finite values. */ - public Options badColor(Tensor badColor) { + public Options badColor(TType badColor) { this.badColor = badColor; return this; } private Long maxImages; - private Tensor badColor; + private TType badColor; private Options() { } @@ -150,7 +149,7 @@ public static Options maxImages(Long maxImages) { /** * @param badColor Color to use for pixels with non-finite values. */ - public static Options badColor(Tensor badColor) { + public static Options badColor(TType badColor) { return new Options().badColor(badColor); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java index 416251aa6b1..20baf84097f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a `Summary` protocol buffer with scalar values. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java index 2069cefafa4..e460ecc8ca2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java index 757ddf59a1c..ff10607477b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java index f173651001a..af1291b964f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java index 79dc79410a1..d4f982dce61 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * An Op to sum inputs across replicated TPU instances. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java index bf4da86d05d..ea2961fa2b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Eases the porting of code that uses tf.nn.embedding_lookup(). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java index 2cb7dfb674b..67c4d41ea03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * An op that enqueues TPUEmbedding input indices from a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java index 3d93c6a0f71..df9ea11bf37 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java index ad4d5f52fa8..1713f1f1a4d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java @@ -17,7 +17,6 @@ package org.tensorflow.op.tpu; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -45,7 +44,7 @@ public final class InfeedDequeue extends RawOp implements Opera * @return a new instance of InfeedDequeue */ @Endpoint(describeByClass = true) - public static InfeedDequeue create(Scope scope, DataType dtype, Shape shape) { + public static InfeedDequeue create(Scope scope, Class dtype, Shape shape) { OperationBuilder opBuilder = scope.env().opBuilder("InfeedDequeue", scope.makeOpName("InfeedDequeue")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java index f471da5b5d7..b90665f21c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,10 +45,10 @@ public final class InfeedDequeueTuple extends RawOp implements Iterable> dtypes, List shapes) { + public static InfeedDequeueTuple create(Scope scope, List> dtypes, List shapes) { OperationBuilder opBuilder = scope.env().opBuilder("InfeedDequeueTuple", scope.makeOpName("InfeedDequeueTuple")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java index 3be811cd039..c11a81f1e54 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java @@ -17,7 +17,6 @@ package org.tensorflow.op.tpu; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -69,7 +68,7 @@ private Options() { * @return a new instance of OutfeedDequeue */ @Endpoint(describeByClass = true) - public static OutfeedDequeue create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static OutfeedDequeue create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OutfeedDequeue", scope.makeOpName("OutfeedDequeue")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java index 6b9110232d8..284c2d7a0fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -71,10 +70,10 @@ private Options() { * @return a new instance of OutfeedDequeueTuple */ @Endpoint(describeByClass = true) - public static OutfeedDequeueTuple create(Scope scope, List> dtypes, List shapes, Options... options) { + public static OutfeedDequeueTuple create(Scope scope, List> dtypes, List shapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OutfeedDequeueTuple", scope.makeOpName("OutfeedDequeueTuple")); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java index 51de5485845..c5371e853c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java @@ -17,7 +17,6 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -55,7 +54,7 @@ public final class AccumulatorTakeGradient extends RawOp implem * @return a new instance of AccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static AccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static AccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("AccumulatorTakeGradient", scope.makeOpName("AccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java index f35e9f5001f..aaf8a6f767a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java @@ -17,7 +17,6 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -92,7 +91,7 @@ private Options() { * @return a new instance of ConditionalAccumulator */ @Endpoint(describeByClass = true) - public static ConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static ConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ConditionalAccumulator", scope.makeOpName("ConditionalAccumulator")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java index c27fb1db524..75601de5579 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java @@ -17,7 +17,6 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,7 +52,7 @@ public final class ResourceAccumulatorTakeGradient extends RawO * @return a new instance of ResourceAccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static ResourceAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static ResourceAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceAccumulatorTakeGradient", scope.makeOpName("ResourceAccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java index 2935d3ca6bc..9bc3453a72a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java @@ -17,7 +17,6 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -92,7 +91,7 @@ private Options() { * @return a new instance of ResourceConditionalAccumulator */ @Endpoint(describeByClass = true) - public static ResourceConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static ResourceConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceConditionalAccumulator", scope.makeOpName("ResourceConditionalAccumulator")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java index 5a2466f501d..721e40e058a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -65,13 +64,13 @@ public final class Restore extends RawOp implements Iterable> { * @return a new instance of Restore */ @Endpoint(describeByClass = true) - public static Restore create(Scope scope, Operand prefix, Operand tensorNames, Operand shapeAndSlices, List> dtypes) { + public static Restore create(Scope scope, Operand prefix, Operand tensorNames, Operand shapeAndSlices, List> dtypes) { OperationBuilder opBuilder = scope.env().opBuilder("RestoreV2", scope.makeOpName("Restore")); opBuilder.addInput(prefix.asOutput()); opBuilder.addInput(tensorNames.asOutput()); opBuilder.addInput(shapeAndSlices.asOutput()); opBuilder = scope.applyControlDependencies(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; + Class[] dtypesArray = new Class[dtypes.size()]; for (int i = 0; i < dtypesArray.length; ++i) { dtypesArray[i] = dtypes.get(i); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java index 631a127a44f..24b13ac30af 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java @@ -17,7 +17,6 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -79,7 +78,7 @@ private Options() { * @return a new instance of RestoreSlice */ @Endpoint(describeByClass = true) - public static RestoreSlice create(Scope scope, Operand filePattern, Operand tensorName, Operand shapeAndSlice, DataType dt, Options... options) { + public static RestoreSlice create(Scope scope, Operand filePattern, Operand tensorName, Operand shapeAndSlice, Class dt, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RestoreSlice", scope.makeOpName("RestoreSlice")); opBuilder.addInput(filePattern.asOutput()); opBuilder.addInput(tensorName.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java index fe26ee27587..63cdcc2767b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java @@ -17,7 +17,6 @@ package org.tensorflow.op.xla; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -50,7 +49,7 @@ public final class Recv extends RawOp implements Operand { * @return a new instance of Recv */ @Endpoint(describeByClass = true) - public static Recv create(Scope scope, DataType dtype, String tensorName, Shape shape) { + public static Recv create(Scope scope, Class dtype, String tensorName, Shape shape) { OperationBuilder opBuilder = scope.env().opBuilder("XlaRecv", scope.makeOpName("Recv")); opBuilder = scope.applyControlDependencies(opBuilder); opBuilder.setAttr("dtype", dtype); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Device.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Device.java deleted file mode 100644 index 59507a00680..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Device.java +++ /dev/null @@ -1,977 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -/** - *

- * A 'device' is a physical entity in the system and is comprised of several
- * resources.
- * 
- * - * Protobuf type {@code tensorflow.profiler.Device} - */ -public final class Device extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:tensorflow.profiler.Device) - DeviceOrBuilder { -private static final long serialVersionUID = 0L; - // Use Device.newBuilder() to construct. - private Device(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private Device() { - name_ = ""; - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new Device(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private Device( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - java.lang.String s = input.readStringRequireUtf8(); - - name_ = s; - break; - } - case 16: { - - deviceId_ = input.readUInt32(); - break; - } - case 26: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - resources_ = com.google.protobuf.MapField.newMapField( - ResourcesDefaultEntryHolder.defaultEntry); - mutable_bitField0_ |= 0x00000001; - } - com.google.protobuf.MapEntry - resources__ = input.readMessage( - ResourcesDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry); - resources_.getMutableMap().put( - resources__.getKey(), resources__.getValue()); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - @java.lang.Override - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 3: - return internalGetResources(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Device.class, org.tensorflow.proto.framework.Device.Builder.class); - } - - public static final int NAME_FIELD_NUMBER = 1; - private volatile java.lang.Object name_; - /** - *
-   * The name of the device.
-   * 
- * - * string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } - } - /** - *
-   * The name of the device.
-   * 
- * - * string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - - public static final int DEVICE_ID_FIELD_NUMBER = 2; - private int deviceId_; - /** - *
-   * The id of this device, unique in a single trace.
-   * 
- * - * uint32 device_id = 2; - */ - public int getDeviceId() { - return deviceId_; - } - - public static final int RESOURCES_FIELD_NUMBER = 3; - private static final class ResourcesDefaultEntryHolder { - static final com.google.protobuf.MapEntry< - java.lang.Integer, org.tensorflow.proto.framework.Resource> defaultEntry = - com.google.protobuf.MapEntry - .newDefaultInstance( - org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_ResourcesEntry_descriptor, - com.google.protobuf.WireFormat.FieldType.UINT32, - 0, - com.google.protobuf.WireFormat.FieldType.MESSAGE, - org.tensorflow.proto.framework.Resource.getDefaultInstance()); - } - private com.google.protobuf.MapField< - java.lang.Integer, org.tensorflow.proto.framework.Resource> resources_; - private com.google.protobuf.MapField - internalGetResources() { - if (resources_ == null) { - return com.google.protobuf.MapField.emptyMapField( - ResourcesDefaultEntryHolder.defaultEntry); - } - return resources_; - } - - public int getResourcesCount() { - return internalGetResources().getMap().size(); - } - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public boolean containsResources( - int key) { - - return internalGetResources().getMap().containsKey(key); - } - /** - * Use {@link #getResourcesMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getResources() { - return getResourcesMap(); - } - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public java.util.Map getResourcesMap() { - return internalGetResources().getMap(); - } - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public org.tensorflow.proto.framework.Resource getResourcesOrDefault( - int key, - org.tensorflow.proto.framework.Resource defaultValue) { - - java.util.Map map = - internalGetResources().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public org.tensorflow.proto.framework.Resource getResourcesOrThrow( - int key) { - - java.util.Map map = - internalGetResources().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (!getNameBytes().isEmpty()) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 1, name_); - } - if (deviceId_ != 0) { - output.writeUInt32(2, deviceId_); - } - com.google.protobuf.GeneratedMessageV3 - .serializeIntegerMapTo( - output, - internalGetResources(), - ResourcesDefaultEntryHolder.defaultEntry, - 3); - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - if (!getNameBytes().isEmpty()) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, name_); - } - if (deviceId_ != 0) { - size += com.google.protobuf.CodedOutputStream - .computeUInt32Size(2, deviceId_); - } - for (java.util.Map.Entry entry - : internalGetResources().getMap().entrySet()) { - com.google.protobuf.MapEntry - resources__ = ResourcesDefaultEntryHolder.defaultEntry.newBuilderForType() - .setKey(entry.getKey()) - .setValue(entry.getValue()) - .build(); - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(3, resources__); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof org.tensorflow.proto.framework.Device)) { - return super.equals(obj); - } - org.tensorflow.proto.framework.Device other = (org.tensorflow.proto.framework.Device) obj; - - if (!getName() - .equals(other.getName())) return false; - if (getDeviceId() - != other.getDeviceId()) return false; - if (!internalGetResources().equals( - other.internalGetResources())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + NAME_FIELD_NUMBER; - hash = (53 * hash) + getName().hashCode(); - hash = (37 * hash) + DEVICE_ID_FIELD_NUMBER; - hash = (53 * hash) + getDeviceId(); - if (!internalGetResources().getMap().isEmpty()) { - hash = (37 * hash) + RESOURCES_FIELD_NUMBER; - hash = (53 * hash) + internalGetResources().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static org.tensorflow.proto.framework.Device parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Device parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Device parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Device parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Device parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Device parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Device parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Device parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Device parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Device parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Device parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Device parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(org.tensorflow.proto.framework.Device prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - *
-   * A 'device' is a physical entity in the system and is comprised of several
-   * resources.
-   * 
- * - * Protobuf type {@code tensorflow.profiler.Device} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:tensorflow.profiler.Device) - org.tensorflow.proto.framework.DeviceOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 3: - return internalGetResources(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMutableMapField( - int number) { - switch (number) { - case 3: - return internalGetMutableResources(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Device.class, org.tensorflow.proto.framework.Device.Builder.class); - } - - // Construct using org.tensorflow.proto.framework.Device.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - } - } - @java.lang.Override - public Builder clear() { - super.clear(); - name_ = ""; - - deviceId_ = 0; - - internalGetMutableResources().clear(); - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Device_descriptor; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Device getDefaultInstanceForType() { - return org.tensorflow.proto.framework.Device.getDefaultInstance(); - } - - @java.lang.Override - public org.tensorflow.proto.framework.Device build() { - org.tensorflow.proto.framework.Device result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Device buildPartial() { - org.tensorflow.proto.framework.Device result = new org.tensorflow.proto.framework.Device(this); - int from_bitField0_ = bitField0_; - result.name_ = name_; - result.deviceId_ = deviceId_; - result.resources_ = internalGetResources(); - result.resources_.makeImmutable(); - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.tensorflow.proto.framework.Device) { - return mergeFrom((org.tensorflow.proto.framework.Device)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(org.tensorflow.proto.framework.Device other) { - if (other == org.tensorflow.proto.framework.Device.getDefaultInstance()) return this; - if (!other.getName().isEmpty()) { - name_ = other.name_; - onChanged(); - } - if (other.getDeviceId() != 0) { - setDeviceId(other.getDeviceId()); - } - internalGetMutableResources().mergeFrom( - other.internalGetResources()); - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - org.tensorflow.proto.framework.Device parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (org.tensorflow.proto.framework.Device) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - private int bitField0_; - - private java.lang.Object name_ = ""; - /** - *
-     * The name of the device.
-     * 
- * - * string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } else { - return (java.lang.String) ref; - } - } - /** - *
-     * The name of the device.
-     * 
- * - * string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - /** - *
-     * The name of the device.
-     * 
- * - * string name = 1; - */ - public Builder setName( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - - name_ = value; - onChanged(); - return this; - } - /** - *
-     * The name of the device.
-     * 
- * - * string name = 1; - */ - public Builder clearName() { - - name_ = getDefaultInstance().getName(); - onChanged(); - return this; - } - /** - *
-     * The name of the device.
-     * 
- * - * string name = 1; - */ - public Builder setNameBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); - - name_ = value; - onChanged(); - return this; - } - - private int deviceId_ ; - /** - *
-     * The id of this device, unique in a single trace.
-     * 
- * - * uint32 device_id = 2; - */ - public int getDeviceId() { - return deviceId_; - } - /** - *
-     * The id of this device, unique in a single trace.
-     * 
- * - * uint32 device_id = 2; - */ - public Builder setDeviceId(int value) { - - deviceId_ = value; - onChanged(); - return this; - } - /** - *
-     * The id of this device, unique in a single trace.
-     * 
- * - * uint32 device_id = 2; - */ - public Builder clearDeviceId() { - - deviceId_ = 0; - onChanged(); - return this; - } - - private com.google.protobuf.MapField< - java.lang.Integer, org.tensorflow.proto.framework.Resource> resources_; - private com.google.protobuf.MapField - internalGetResources() { - if (resources_ == null) { - return com.google.protobuf.MapField.emptyMapField( - ResourcesDefaultEntryHolder.defaultEntry); - } - return resources_; - } - private com.google.protobuf.MapField - internalGetMutableResources() { - onChanged();; - if (resources_ == null) { - resources_ = com.google.protobuf.MapField.newMapField( - ResourcesDefaultEntryHolder.defaultEntry); - } - if (!resources_.isMutable()) { - resources_ = resources_.copy(); - } - return resources_; - } - - public int getResourcesCount() { - return internalGetResources().getMap().size(); - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public boolean containsResources( - int key) { - - return internalGetResources().getMap().containsKey(key); - } - /** - * Use {@link #getResourcesMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getResources() { - return getResourcesMap(); - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public java.util.Map getResourcesMap() { - return internalGetResources().getMap(); - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public org.tensorflow.proto.framework.Resource getResourcesOrDefault( - int key, - org.tensorflow.proto.framework.Resource defaultValue) { - - java.util.Map map = - internalGetResources().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public org.tensorflow.proto.framework.Resource getResourcesOrThrow( - int key) { - - java.util.Map map = - internalGetResources().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - public Builder clearResources() { - internalGetMutableResources().getMutableMap() - .clear(); - return this; - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public Builder removeResources( - int key) { - - internalGetMutableResources().getMutableMap() - .remove(key); - return this; - } - /** - * Use alternate mutation accessors instead. - */ - @java.lang.Deprecated - public java.util.Map - getMutableResources() { - return internalGetMutableResources().getMutableMap(); - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - public Builder putResources( - int key, - org.tensorflow.proto.framework.Resource value) { - - if (value == null) { throw new java.lang.NullPointerException(); } - internalGetMutableResources().getMutableMap() - .put(key, value); - return this; - } - /** - *
-     * The resources on this device, keyed by resource_id;
-     * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - public Builder putAllResources( - java.util.Map values) { - internalGetMutableResources().getMutableMap() - .putAll(values); - return this; - } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:tensorflow.profiler.Device) - } - - // @@protoc_insertion_point(class_scope:tensorflow.profiler.Device) - private static final org.tensorflow.proto.framework.Device DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new org.tensorflow.proto.framework.Device(); - } - - public static org.tensorflow.proto.framework.Device getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public Device parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new Device(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Device getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - -} - diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/DeviceOrBuilder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/DeviceOrBuilder.java deleted file mode 100644 index 3ce193459ff..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/DeviceOrBuilder.java +++ /dev/null @@ -1,90 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -public interface DeviceOrBuilder extends - // @@protoc_insertion_point(interface_extends:tensorflow.profiler.Device) - com.google.protobuf.MessageOrBuilder { - - /** - *
-   * The name of the device.
-   * 
- * - * string name = 1; - */ - java.lang.String getName(); - /** - *
-   * The name of the device.
-   * 
- * - * string name = 1; - */ - com.google.protobuf.ByteString - getNameBytes(); - - /** - *
-   * The id of this device, unique in a single trace.
-   * 
- * - * uint32 device_id = 2; - */ - int getDeviceId(); - - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - int getResourcesCount(); - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - boolean containsResources( - int key); - /** - * Use {@link #getResourcesMap()} instead. - */ - @java.lang.Deprecated - java.util.Map - getResources(); - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - java.util.Map - getResourcesMap(); - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - org.tensorflow.proto.framework.Resource getResourcesOrDefault( - int key, - org.tensorflow.proto.framework.Resource defaultValue); - /** - *
-   * The resources on this device, keyed by resource_id;
-   * 
- * - * map<uint32, .tensorflow.profiler.Resource> resources = 3; - */ - - org.tensorflow.proto.framework.Resource getResourcesOrThrow( - int key); -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Resource.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Resource.java deleted file mode 100644 index 7743aecf24e..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Resource.java +++ /dev/null @@ -1,659 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -/** - *
- * A 'resource' generally is a specific computation component on a device. These
- * can range from threads on CPUs to specific arithmetic units on hardware
- * devices.
- * 
- * - * Protobuf type {@code tensorflow.profiler.Resource} - */ -public final class Resource extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:tensorflow.profiler.Resource) - ResourceOrBuilder { -private static final long serialVersionUID = 0L; - // Use Resource.newBuilder() to construct. - private Resource(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private Resource() { - name_ = ""; - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new Resource(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private Resource( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - java.lang.String s = input.readStringRequireUtf8(); - - name_ = s; - break; - } - case 16: { - - resourceId_ = input.readUInt32(); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Resource_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Resource_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Resource.class, org.tensorflow.proto.framework.Resource.Builder.class); - } - - public static final int NAME_FIELD_NUMBER = 1; - private volatile java.lang.Object name_; - /** - *
-   * The name of the resource.
-   * 
- * - * string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } - } - /** - *
-   * The name of the resource.
-   * 
- * - * string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - - public static final int RESOURCE_ID_FIELD_NUMBER = 2; - private int resourceId_; - /** - *
-   * The id of the resource. Unique within a device.
-   * 
- * - * uint32 resource_id = 2; - */ - public int getResourceId() { - return resourceId_; - } - - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (!getNameBytes().isEmpty()) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 1, name_); - } - if (resourceId_ != 0) { - output.writeUInt32(2, resourceId_); - } - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - if (!getNameBytes().isEmpty()) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, name_); - } - if (resourceId_ != 0) { - size += com.google.protobuf.CodedOutputStream - .computeUInt32Size(2, resourceId_); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof org.tensorflow.proto.framework.Resource)) { - return super.equals(obj); - } - org.tensorflow.proto.framework.Resource other = (org.tensorflow.proto.framework.Resource) obj; - - if (!getName() - .equals(other.getName())) return false; - if (getResourceId() - != other.getResourceId()) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + NAME_FIELD_NUMBER; - hash = (53 * hash) + getName().hashCode(); - hash = (37 * hash) + RESOURCE_ID_FIELD_NUMBER; - hash = (53 * hash) + getResourceId(); - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static org.tensorflow.proto.framework.Resource parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Resource parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Resource parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Resource parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Resource parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Resource parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(org.tensorflow.proto.framework.Resource prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - *
-   * A 'resource' generally is a specific computation component on a device. These
-   * can range from threads on CPUs to specific arithmetic units on hardware
-   * devices.
-   * 
- * - * Protobuf type {@code tensorflow.profiler.Resource} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:tensorflow.profiler.Resource) - org.tensorflow.proto.framework.ResourceOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Resource_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Resource_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Resource.class, org.tensorflow.proto.framework.Resource.Builder.class); - } - - // Construct using org.tensorflow.proto.framework.Resource.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - } - } - @java.lang.Override - public Builder clear() { - super.clear(); - name_ = ""; - - resourceId_ = 0; - - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Resource_descriptor; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Resource getDefaultInstanceForType() { - return org.tensorflow.proto.framework.Resource.getDefaultInstance(); - } - - @java.lang.Override - public org.tensorflow.proto.framework.Resource build() { - org.tensorflow.proto.framework.Resource result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Resource buildPartial() { - org.tensorflow.proto.framework.Resource result = new org.tensorflow.proto.framework.Resource(this); - result.name_ = name_; - result.resourceId_ = resourceId_; - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.tensorflow.proto.framework.Resource) { - return mergeFrom((org.tensorflow.proto.framework.Resource)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(org.tensorflow.proto.framework.Resource other) { - if (other == org.tensorflow.proto.framework.Resource.getDefaultInstance()) return this; - if (!other.getName().isEmpty()) { - name_ = other.name_; - onChanged(); - } - if (other.getResourceId() != 0) { - setResourceId(other.getResourceId()); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - org.tensorflow.proto.framework.Resource parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (org.tensorflow.proto.framework.Resource) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - - private java.lang.Object name_ = ""; - /** - *
-     * The name of the resource.
-     * 
- * - * string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } else { - return (java.lang.String) ref; - } - } - /** - *
-     * The name of the resource.
-     * 
- * - * string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - /** - *
-     * The name of the resource.
-     * 
- * - * string name = 1; - */ - public Builder setName( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - - name_ = value; - onChanged(); - return this; - } - /** - *
-     * The name of the resource.
-     * 
- * - * string name = 1; - */ - public Builder clearName() { - - name_ = getDefaultInstance().getName(); - onChanged(); - return this; - } - /** - *
-     * The name of the resource.
-     * 
- * - * string name = 1; - */ - public Builder setNameBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); - - name_ = value; - onChanged(); - return this; - } - - private int resourceId_ ; - /** - *
-     * The id of the resource. Unique within a device.
-     * 
- * - * uint32 resource_id = 2; - */ - public int getResourceId() { - return resourceId_; - } - /** - *
-     * The id of the resource. Unique within a device.
-     * 
- * - * uint32 resource_id = 2; - */ - public Builder setResourceId(int value) { - - resourceId_ = value; - onChanged(); - return this; - } - /** - *
-     * The id of the resource. Unique within a device.
-     * 
- * - * uint32 resource_id = 2; - */ - public Builder clearResourceId() { - - resourceId_ = 0; - onChanged(); - return this; - } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:tensorflow.profiler.Resource) - } - - // @@protoc_insertion_point(class_scope:tensorflow.profiler.Resource) - private static final org.tensorflow.proto.framework.Resource DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new org.tensorflow.proto.framework.Resource(); - } - - public static org.tensorflow.proto.framework.Resource getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public Resource parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new Resource(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Resource getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - -} - diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/ResourceOrBuilder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/ResourceOrBuilder.java deleted file mode 100644 index f6bc0c0d190..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/ResourceOrBuilder.java +++ /dev/null @@ -1,36 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -public interface ResourceOrBuilder extends - // @@protoc_insertion_point(interface_extends:tensorflow.profiler.Resource) - com.google.protobuf.MessageOrBuilder { - - /** - *
-   * The name of the resource.
-   * 
- * - * string name = 1; - */ - java.lang.String getName(); - /** - *
-   * The name of the resource.
-   * 
- * - * string name = 1; - */ - com.google.protobuf.ByteString - getNameBytes(); - - /** - *
-   * The id of the resource. Unique within a device.
-   * 
- * - * uint32 resource_id = 2; - */ - int getResourceId(); -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Trace.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Trace.java deleted file mode 100644 index 09d97221dde..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/Trace.java +++ /dev/null @@ -1,1193 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -/** - *
- * A 'Trace' contains metadata for the individual traces of a system.
- * 
- * - * Protobuf type {@code tensorflow.profiler.Trace} - */ -public final class Trace extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:tensorflow.profiler.Trace) - TraceOrBuilder { -private static final long serialVersionUID = 0L; - // Use Trace.newBuilder() to construct. - private Trace(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private Trace() { - traceEvents_ = java.util.Collections.emptyList(); - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new Trace(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private Trace( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - devices_ = com.google.protobuf.MapField.newMapField( - DevicesDefaultEntryHolder.defaultEntry); - mutable_bitField0_ |= 0x00000001; - } - com.google.protobuf.MapEntry - devices__ = input.readMessage( - DevicesDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry); - devices_.getMutableMap().put( - devices__.getKey(), devices__.getValue()); - break; - } - case 34: { - if (!((mutable_bitField0_ & 0x00000002) != 0)) { - traceEvents_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000002; - } - traceEvents_.add( - input.readMessage(org.tensorflow.proto.framework.TraceEvent.parser(), extensionRegistry)); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000002) != 0)) { - traceEvents_ = java.util.Collections.unmodifiableList(traceEvents_); - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - @java.lang.Override - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 1: - return internalGetDevices(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Trace.class, org.tensorflow.proto.framework.Trace.Builder.class); - } - - public static final int DEVICES_FIELD_NUMBER = 1; - private static final class DevicesDefaultEntryHolder { - static final com.google.protobuf.MapEntry< - java.lang.Integer, org.tensorflow.proto.framework.Device> defaultEntry = - com.google.protobuf.MapEntry - .newDefaultInstance( - org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_DevicesEntry_descriptor, - com.google.protobuf.WireFormat.FieldType.UINT32, - 0, - com.google.protobuf.WireFormat.FieldType.MESSAGE, - org.tensorflow.proto.framework.Device.getDefaultInstance()); - } - private com.google.protobuf.MapField< - java.lang.Integer, org.tensorflow.proto.framework.Device> devices_; - private com.google.protobuf.MapField - internalGetDevices() { - if (devices_ == null) { - return com.google.protobuf.MapField.emptyMapField( - DevicesDefaultEntryHolder.defaultEntry); - } - return devices_; - } - - public int getDevicesCount() { - return internalGetDevices().getMap().size(); - } - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public boolean containsDevices( - int key) { - - return internalGetDevices().getMap().containsKey(key); - } - /** - * Use {@link #getDevicesMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getDevices() { - return getDevicesMap(); - } - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public java.util.Map getDevicesMap() { - return internalGetDevices().getMap(); - } - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public org.tensorflow.proto.framework.Device getDevicesOrDefault( - int key, - org.tensorflow.proto.framework.Device defaultValue) { - - java.util.Map map = - internalGetDevices().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public org.tensorflow.proto.framework.Device getDevicesOrThrow( - int key) { - - java.util.Map map = - internalGetDevices().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - public static final int TRACE_EVENTS_FIELD_NUMBER = 4; - private java.util.List traceEvents_; - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public java.util.List getTraceEventsList() { - return traceEvents_; - } - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public java.util.List - getTraceEventsOrBuilderList() { - return traceEvents_; - } - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public int getTraceEventsCount() { - return traceEvents_.size(); - } - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEvent getTraceEvents(int index) { - return traceEvents_.get(index); - } - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEventOrBuilder getTraceEventsOrBuilder( - int index) { - return traceEvents_.get(index); - } - - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - com.google.protobuf.GeneratedMessageV3 - .serializeIntegerMapTo( - output, - internalGetDevices(), - DevicesDefaultEntryHolder.defaultEntry, - 1); - for (int i = 0; i < traceEvents_.size(); i++) { - output.writeMessage(4, traceEvents_.get(i)); - } - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - for (java.util.Map.Entry entry - : internalGetDevices().getMap().entrySet()) { - com.google.protobuf.MapEntry - devices__ = DevicesDefaultEntryHolder.defaultEntry.newBuilderForType() - .setKey(entry.getKey()) - .setValue(entry.getValue()) - .build(); - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(1, devices__); - } - for (int i = 0; i < traceEvents_.size(); i++) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(4, traceEvents_.get(i)); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof org.tensorflow.proto.framework.Trace)) { - return super.equals(obj); - } - org.tensorflow.proto.framework.Trace other = (org.tensorflow.proto.framework.Trace) obj; - - if (!internalGetDevices().equals( - other.internalGetDevices())) return false; - if (!getTraceEventsList() - .equals(other.getTraceEventsList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (!internalGetDevices().getMap().isEmpty()) { - hash = (37 * hash) + DEVICES_FIELD_NUMBER; - hash = (53 * hash) + internalGetDevices().hashCode(); - } - if (getTraceEventsCount() > 0) { - hash = (37 * hash) + TRACE_EVENTS_FIELD_NUMBER; - hash = (53 * hash) + getTraceEventsList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static org.tensorflow.proto.framework.Trace parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Trace parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.Trace parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Trace parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Trace parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.Trace parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(org.tensorflow.proto.framework.Trace prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - *
-   * A 'Trace' contains metadata for the individual traces of a system.
-   * 
- * - * Protobuf type {@code tensorflow.profiler.Trace} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:tensorflow.profiler.Trace) - org.tensorflow.proto.framework.TraceOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 1: - return internalGetDevices(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMutableMapField( - int number) { - switch (number) { - case 1: - return internalGetMutableDevices(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.Trace.class, org.tensorflow.proto.framework.Trace.Builder.class); - } - - // Construct using org.tensorflow.proto.framework.Trace.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - getTraceEventsFieldBuilder(); - } - } - @java.lang.Override - public Builder clear() { - super.clear(); - internalGetMutableDevices().clear(); - if (traceEventsBuilder_ == null) { - traceEvents_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000002); - } else { - traceEventsBuilder_.clear(); - } - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_Trace_descriptor; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Trace getDefaultInstanceForType() { - return org.tensorflow.proto.framework.Trace.getDefaultInstance(); - } - - @java.lang.Override - public org.tensorflow.proto.framework.Trace build() { - org.tensorflow.proto.framework.Trace result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Trace buildPartial() { - org.tensorflow.proto.framework.Trace result = new org.tensorflow.proto.framework.Trace(this); - int from_bitField0_ = bitField0_; - result.devices_ = internalGetDevices(); - result.devices_.makeImmutable(); - if (traceEventsBuilder_ == null) { - if (((bitField0_ & 0x00000002) != 0)) { - traceEvents_ = java.util.Collections.unmodifiableList(traceEvents_); - bitField0_ = (bitField0_ & ~0x00000002); - } - result.traceEvents_ = traceEvents_; - } else { - result.traceEvents_ = traceEventsBuilder_.build(); - } - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.tensorflow.proto.framework.Trace) { - return mergeFrom((org.tensorflow.proto.framework.Trace)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(org.tensorflow.proto.framework.Trace other) { - if (other == org.tensorflow.proto.framework.Trace.getDefaultInstance()) return this; - internalGetMutableDevices().mergeFrom( - other.internalGetDevices()); - if (traceEventsBuilder_ == null) { - if (!other.traceEvents_.isEmpty()) { - if (traceEvents_.isEmpty()) { - traceEvents_ = other.traceEvents_; - bitField0_ = (bitField0_ & ~0x00000002); - } else { - ensureTraceEventsIsMutable(); - traceEvents_.addAll(other.traceEvents_); - } - onChanged(); - } - } else { - if (!other.traceEvents_.isEmpty()) { - if (traceEventsBuilder_.isEmpty()) { - traceEventsBuilder_.dispose(); - traceEventsBuilder_ = null; - traceEvents_ = other.traceEvents_; - bitField0_ = (bitField0_ & ~0x00000002); - traceEventsBuilder_ = - com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? - getTraceEventsFieldBuilder() : null; - } else { - traceEventsBuilder_.addAllMessages(other.traceEvents_); - } - } - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - org.tensorflow.proto.framework.Trace parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (org.tensorflow.proto.framework.Trace) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - private int bitField0_; - - private com.google.protobuf.MapField< - java.lang.Integer, org.tensorflow.proto.framework.Device> devices_; - private com.google.protobuf.MapField - internalGetDevices() { - if (devices_ == null) { - return com.google.protobuf.MapField.emptyMapField( - DevicesDefaultEntryHolder.defaultEntry); - } - return devices_; - } - private com.google.protobuf.MapField - internalGetMutableDevices() { - onChanged();; - if (devices_ == null) { - devices_ = com.google.protobuf.MapField.newMapField( - DevicesDefaultEntryHolder.defaultEntry); - } - if (!devices_.isMutable()) { - devices_ = devices_.copy(); - } - return devices_; - } - - public int getDevicesCount() { - return internalGetDevices().getMap().size(); - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public boolean containsDevices( - int key) { - - return internalGetDevices().getMap().containsKey(key); - } - /** - * Use {@link #getDevicesMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getDevices() { - return getDevicesMap(); - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public java.util.Map getDevicesMap() { - return internalGetDevices().getMap(); - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public org.tensorflow.proto.framework.Device getDevicesOrDefault( - int key, - org.tensorflow.proto.framework.Device defaultValue) { - - java.util.Map map = - internalGetDevices().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public org.tensorflow.proto.framework.Device getDevicesOrThrow( - int key) { - - java.util.Map map = - internalGetDevices().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - public Builder clearDevices() { - internalGetMutableDevices().getMutableMap() - .clear(); - return this; - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public Builder removeDevices( - int key) { - - internalGetMutableDevices().getMutableMap() - .remove(key); - return this; - } - /** - * Use alternate mutation accessors instead. - */ - @java.lang.Deprecated - public java.util.Map - getMutableDevices() { - return internalGetMutableDevices().getMutableMap(); - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - public Builder putDevices( - int key, - org.tensorflow.proto.framework.Device value) { - - if (value == null) { throw new java.lang.NullPointerException(); } - internalGetMutableDevices().getMutableMap() - .put(key, value); - return this; - } - /** - *
-     * The devices that this trace has information about. Maps from device_id to
-     * more data about the specific device.
-     * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - public Builder putAllDevices( - java.util.Map values) { - internalGetMutableDevices().getMutableMap() - .putAll(values); - return this; - } - - private java.util.List traceEvents_ = - java.util.Collections.emptyList(); - private void ensureTraceEventsIsMutable() { - if (!((bitField0_ & 0x00000002) != 0)) { - traceEvents_ = new java.util.ArrayList(traceEvents_); - bitField0_ |= 0x00000002; - } - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - org.tensorflow.proto.framework.TraceEvent, org.tensorflow.proto.framework.TraceEvent.Builder, org.tensorflow.proto.framework.TraceEventOrBuilder> traceEventsBuilder_; - - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public java.util.List getTraceEventsList() { - if (traceEventsBuilder_ == null) { - return java.util.Collections.unmodifiableList(traceEvents_); - } else { - return traceEventsBuilder_.getMessageList(); - } - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public int getTraceEventsCount() { - if (traceEventsBuilder_ == null) { - return traceEvents_.size(); - } else { - return traceEventsBuilder_.getCount(); - } - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEvent getTraceEvents(int index) { - if (traceEventsBuilder_ == null) { - return traceEvents_.get(index); - } else { - return traceEventsBuilder_.getMessage(index); - } - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder setTraceEvents( - int index, org.tensorflow.proto.framework.TraceEvent value) { - if (traceEventsBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureTraceEventsIsMutable(); - traceEvents_.set(index, value); - onChanged(); - } else { - traceEventsBuilder_.setMessage(index, value); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder setTraceEvents( - int index, org.tensorflow.proto.framework.TraceEvent.Builder builderForValue) { - if (traceEventsBuilder_ == null) { - ensureTraceEventsIsMutable(); - traceEvents_.set(index, builderForValue.build()); - onChanged(); - } else { - traceEventsBuilder_.setMessage(index, builderForValue.build()); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder addTraceEvents(org.tensorflow.proto.framework.TraceEvent value) { - if (traceEventsBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureTraceEventsIsMutable(); - traceEvents_.add(value); - onChanged(); - } else { - traceEventsBuilder_.addMessage(value); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder addTraceEvents( - int index, org.tensorflow.proto.framework.TraceEvent value) { - if (traceEventsBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureTraceEventsIsMutable(); - traceEvents_.add(index, value); - onChanged(); - } else { - traceEventsBuilder_.addMessage(index, value); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder addTraceEvents( - org.tensorflow.proto.framework.TraceEvent.Builder builderForValue) { - if (traceEventsBuilder_ == null) { - ensureTraceEventsIsMutable(); - traceEvents_.add(builderForValue.build()); - onChanged(); - } else { - traceEventsBuilder_.addMessage(builderForValue.build()); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder addTraceEvents( - int index, org.tensorflow.proto.framework.TraceEvent.Builder builderForValue) { - if (traceEventsBuilder_ == null) { - ensureTraceEventsIsMutable(); - traceEvents_.add(index, builderForValue.build()); - onChanged(); - } else { - traceEventsBuilder_.addMessage(index, builderForValue.build()); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder addAllTraceEvents( - java.lang.Iterable values) { - if (traceEventsBuilder_ == null) { - ensureTraceEventsIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, traceEvents_); - onChanged(); - } else { - traceEventsBuilder_.addAllMessages(values); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder clearTraceEvents() { - if (traceEventsBuilder_ == null) { - traceEvents_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000002); - onChanged(); - } else { - traceEventsBuilder_.clear(); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public Builder removeTraceEvents(int index) { - if (traceEventsBuilder_ == null) { - ensureTraceEventsIsMutable(); - traceEvents_.remove(index); - onChanged(); - } else { - traceEventsBuilder_.remove(index); - } - return this; - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEvent.Builder getTraceEventsBuilder( - int index) { - return getTraceEventsFieldBuilder().getBuilder(index); - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEventOrBuilder getTraceEventsOrBuilder( - int index) { - if (traceEventsBuilder_ == null) { - return traceEvents_.get(index); } else { - return traceEventsBuilder_.getMessageOrBuilder(index); - } - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public java.util.List - getTraceEventsOrBuilderList() { - if (traceEventsBuilder_ != null) { - return traceEventsBuilder_.getMessageOrBuilderList(); - } else { - return java.util.Collections.unmodifiableList(traceEvents_); - } - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEvent.Builder addTraceEventsBuilder() { - return getTraceEventsFieldBuilder().addBuilder( - org.tensorflow.proto.framework.TraceEvent.getDefaultInstance()); - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public org.tensorflow.proto.framework.TraceEvent.Builder addTraceEventsBuilder( - int index) { - return getTraceEventsFieldBuilder().addBuilder( - index, org.tensorflow.proto.framework.TraceEvent.getDefaultInstance()); - } - /** - *
-     * All trace events capturing in the profiling period.
-     * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - public java.util.List - getTraceEventsBuilderList() { - return getTraceEventsFieldBuilder().getBuilderList(); - } - private com.google.protobuf.RepeatedFieldBuilderV3< - org.tensorflow.proto.framework.TraceEvent, org.tensorflow.proto.framework.TraceEvent.Builder, org.tensorflow.proto.framework.TraceEventOrBuilder> - getTraceEventsFieldBuilder() { - if (traceEventsBuilder_ == null) { - traceEventsBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< - org.tensorflow.proto.framework.TraceEvent, org.tensorflow.proto.framework.TraceEvent.Builder, org.tensorflow.proto.framework.TraceEventOrBuilder>( - traceEvents_, - ((bitField0_ & 0x00000002) != 0), - getParentForChildren(), - isClean()); - traceEvents_ = null; - } - return traceEventsBuilder_; - } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:tensorflow.profiler.Trace) - } - - // @@protoc_insertion_point(class_scope:tensorflow.profiler.Trace) - private static final org.tensorflow.proto.framework.Trace DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new org.tensorflow.proto.framework.Trace(); - } - - public static org.tensorflow.proto.framework.Trace getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public Trace parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new Trace(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public org.tensorflow.proto.framework.Trace getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - -} - diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEvent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEvent.java deleted file mode 100644 index 8f34acfde34..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEvent.java +++ /dev/null @@ -1,1208 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -/** - * Protobuf type {@code tensorflow.profiler.TraceEvent} - */ -public final class TraceEvent extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:tensorflow.profiler.TraceEvent) - TraceEventOrBuilder { -private static final long serialVersionUID = 0L; - // Use TraceEvent.newBuilder() to construct. - private TraceEvent(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private TraceEvent() { - name_ = ""; - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new TraceEvent(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private TraceEvent( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 8: { - - deviceId_ = input.readUInt32(); - break; - } - case 16: { - - resourceId_ = input.readUInt32(); - break; - } - case 26: { - java.lang.String s = input.readStringRequireUtf8(); - - name_ = s; - break; - } - case 72: { - - timestampPs_ = input.readUInt64(); - break; - } - case 80: { - - durationPs_ = input.readUInt64(); - break; - } - case 90: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - args_ = com.google.protobuf.MapField.newMapField( - ArgsDefaultEntryHolder.defaultEntry); - mutable_bitField0_ |= 0x00000001; - } - com.google.protobuf.MapEntry - args__ = input.readMessage( - ArgsDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry); - args_.getMutableMap().put( - args__.getKey(), args__.getValue()); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - @java.lang.Override - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 11: - return internalGetArgs(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.TraceEvent.class, org.tensorflow.proto.framework.TraceEvent.Builder.class); - } - - public static final int DEVICE_ID_FIELD_NUMBER = 1; - private int deviceId_; - /** - *
-   * The id of the device that this event occurred on. The full dataset should
-   * have this device present in the Trace object.
-   * 
- * - * uint32 device_id = 1; - */ - public int getDeviceId() { - return deviceId_; - } - - public static final int RESOURCE_ID_FIELD_NUMBER = 2; - private int resourceId_; - /** - *
-   * The id of the resource that this event occurred on. The full dataset should
-   * have this resource present in the Device object of the Trace object. A
-   * resource_id is unique on a specific device, but not necessarily within the
-   * trace.
-   * 
- * - * uint32 resource_id = 2; - */ - public int getResourceId() { - return resourceId_; - } - - public static final int NAME_FIELD_NUMBER = 3; - private volatile java.lang.Object name_; - /** - *
-   * The name of this trace event.
-   * 
- * - * string name = 3; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } - } - /** - *
-   * The name of this trace event.
-   * 
- * - * string name = 3; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - - public static final int TIMESTAMP_PS_FIELD_NUMBER = 9; - private long timestampPs_; - /** - *
-   * The timestamp that this event occurred at (in picos since tracing started).
-   * 
- * - * uint64 timestamp_ps = 9; - */ - public long getTimestampPs() { - return timestampPs_; - } - - public static final int DURATION_PS_FIELD_NUMBER = 10; - private long durationPs_; - /** - *
-   * The duration of the event in picoseconds if applicable.
-   * Events without duration are called instant events.
-   * 
- * - * uint64 duration_ps = 10; - */ - public long getDurationPs() { - return durationPs_; - } - - public static final int ARGS_FIELD_NUMBER = 11; - private static final class ArgsDefaultEntryHolder { - static final com.google.protobuf.MapEntry< - java.lang.String, java.lang.String> defaultEntry = - com.google.protobuf.MapEntry - .newDefaultInstance( - org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_descriptor, - com.google.protobuf.WireFormat.FieldType.STRING, - "", - com.google.protobuf.WireFormat.FieldType.STRING, - ""); - } - private com.google.protobuf.MapField< - java.lang.String, java.lang.String> args_; - private com.google.protobuf.MapField - internalGetArgs() { - if (args_ == null) { - return com.google.protobuf.MapField.emptyMapField( - ArgsDefaultEntryHolder.defaultEntry); - } - return args_; - } - - public int getArgsCount() { - return internalGetArgs().getMap().size(); - } - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - public boolean containsArgs( - java.lang.String key) { - if (key == null) { throw new java.lang.NullPointerException(); } - return internalGetArgs().getMap().containsKey(key); - } - /** - * Use {@link #getArgsMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getArgs() { - return getArgsMap(); - } - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - public java.util.Map getArgsMap() { - return internalGetArgs().getMap(); - } - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - public java.lang.String getArgsOrDefault( - java.lang.String key, - java.lang.String defaultValue) { - if (key == null) { throw new java.lang.NullPointerException(); } - java.util.Map map = - internalGetArgs().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - public java.lang.String getArgsOrThrow( - java.lang.String key) { - if (key == null) { throw new java.lang.NullPointerException(); } - java.util.Map map = - internalGetArgs().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (deviceId_ != 0) { - output.writeUInt32(1, deviceId_); - } - if (resourceId_ != 0) { - output.writeUInt32(2, resourceId_); - } - if (!getNameBytes().isEmpty()) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 3, name_); - } - if (timestampPs_ != 0L) { - output.writeUInt64(9, timestampPs_); - } - if (durationPs_ != 0L) { - output.writeUInt64(10, durationPs_); - } - com.google.protobuf.GeneratedMessageV3 - .serializeStringMapTo( - output, - internalGetArgs(), - ArgsDefaultEntryHolder.defaultEntry, - 11); - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - if (deviceId_ != 0) { - size += com.google.protobuf.CodedOutputStream - .computeUInt32Size(1, deviceId_); - } - if (resourceId_ != 0) { - size += com.google.protobuf.CodedOutputStream - .computeUInt32Size(2, resourceId_); - } - if (!getNameBytes().isEmpty()) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, name_); - } - if (timestampPs_ != 0L) { - size += com.google.protobuf.CodedOutputStream - .computeUInt64Size(9, timestampPs_); - } - if (durationPs_ != 0L) { - size += com.google.protobuf.CodedOutputStream - .computeUInt64Size(10, durationPs_); - } - for (java.util.Map.Entry entry - : internalGetArgs().getMap().entrySet()) { - com.google.protobuf.MapEntry - args__ = ArgsDefaultEntryHolder.defaultEntry.newBuilderForType() - .setKey(entry.getKey()) - .setValue(entry.getValue()) - .build(); - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(11, args__); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof org.tensorflow.proto.framework.TraceEvent)) { - return super.equals(obj); - } - org.tensorflow.proto.framework.TraceEvent other = (org.tensorflow.proto.framework.TraceEvent) obj; - - if (getDeviceId() - != other.getDeviceId()) return false; - if (getResourceId() - != other.getResourceId()) return false; - if (!getName() - .equals(other.getName())) return false; - if (getTimestampPs() - != other.getTimestampPs()) return false; - if (getDurationPs() - != other.getDurationPs()) return false; - if (!internalGetArgs().equals( - other.internalGetArgs())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + DEVICE_ID_FIELD_NUMBER; - hash = (53 * hash) + getDeviceId(); - hash = (37 * hash) + RESOURCE_ID_FIELD_NUMBER; - hash = (53 * hash) + getResourceId(); - hash = (37 * hash) + NAME_FIELD_NUMBER; - hash = (53 * hash) + getName().hashCode(); - hash = (37 * hash) + TIMESTAMP_PS_FIELD_NUMBER; - hash = (53 * hash) + com.google.protobuf.Internal.hashLong( - getTimestampPs()); - hash = (37 * hash) + DURATION_PS_FIELD_NUMBER; - hash = (53 * hash) + com.google.protobuf.Internal.hashLong( - getDurationPs()); - if (!internalGetArgs().getMap().isEmpty()) { - hash = (37 * hash) + ARGS_FIELD_NUMBER; - hash = (53 * hash) + internalGetArgs().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.TraceEvent parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.TraceEvent parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static org.tensorflow.proto.framework.TraceEvent parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(org.tensorflow.proto.framework.TraceEvent prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - * Protobuf type {@code tensorflow.profiler.TraceEvent} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:tensorflow.profiler.TraceEvent) - org.tensorflow.proto.framework.TraceEventOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 11: - return internalGetArgs(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMutableMapField( - int number) { - switch (number) { - case 11: - return internalGetMutableArgs(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.tensorflow.proto.framework.TraceEvent.class, org.tensorflow.proto.framework.TraceEvent.Builder.class); - } - - // Construct using org.tensorflow.proto.framework.TraceEvent.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - } - } - @java.lang.Override - public Builder clear() { - super.clear(); - deviceId_ = 0; - - resourceId_ = 0; - - name_ = ""; - - timestampPs_ = 0L; - - durationPs_ = 0L; - - internalGetMutableArgs().clear(); - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return org.tensorflow.proto.framework.TraceEventsProtos.internal_static_tensorflow_profiler_TraceEvent_descriptor; - } - - @java.lang.Override - public org.tensorflow.proto.framework.TraceEvent getDefaultInstanceForType() { - return org.tensorflow.proto.framework.TraceEvent.getDefaultInstance(); - } - - @java.lang.Override - public org.tensorflow.proto.framework.TraceEvent build() { - org.tensorflow.proto.framework.TraceEvent result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public org.tensorflow.proto.framework.TraceEvent buildPartial() { - org.tensorflow.proto.framework.TraceEvent result = new org.tensorflow.proto.framework.TraceEvent(this); - int from_bitField0_ = bitField0_; - result.deviceId_ = deviceId_; - result.resourceId_ = resourceId_; - result.name_ = name_; - result.timestampPs_ = timestampPs_; - result.durationPs_ = durationPs_; - result.args_ = internalGetArgs(); - result.args_.makeImmutable(); - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.tensorflow.proto.framework.TraceEvent) { - return mergeFrom((org.tensorflow.proto.framework.TraceEvent)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(org.tensorflow.proto.framework.TraceEvent other) { - if (other == org.tensorflow.proto.framework.TraceEvent.getDefaultInstance()) return this; - if (other.getDeviceId() != 0) { - setDeviceId(other.getDeviceId()); - } - if (other.getResourceId() != 0) { - setResourceId(other.getResourceId()); - } - if (!other.getName().isEmpty()) { - name_ = other.name_; - onChanged(); - } - if (other.getTimestampPs() != 0L) { - setTimestampPs(other.getTimestampPs()); - } - if (other.getDurationPs() != 0L) { - setDurationPs(other.getDurationPs()); - } - internalGetMutableArgs().mergeFrom( - other.internalGetArgs()); - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - org.tensorflow.proto.framework.TraceEvent parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (org.tensorflow.proto.framework.TraceEvent) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - private int bitField0_; - - private int deviceId_ ; - /** - *
-     * The id of the device that this event occurred on. The full dataset should
-     * have this device present in the Trace object.
-     * 
- * - * uint32 device_id = 1; - */ - public int getDeviceId() { - return deviceId_; - } - /** - *
-     * The id of the device that this event occurred on. The full dataset should
-     * have this device present in the Trace object.
-     * 
- * - * uint32 device_id = 1; - */ - public Builder setDeviceId(int value) { - - deviceId_ = value; - onChanged(); - return this; - } - /** - *
-     * The id of the device that this event occurred on. The full dataset should
-     * have this device present in the Trace object.
-     * 
- * - * uint32 device_id = 1; - */ - public Builder clearDeviceId() { - - deviceId_ = 0; - onChanged(); - return this; - } - - private int resourceId_ ; - /** - *
-     * The id of the resource that this event occurred on. The full dataset should
-     * have this resource present in the Device object of the Trace object. A
-     * resource_id is unique on a specific device, but not necessarily within the
-     * trace.
-     * 
- * - * uint32 resource_id = 2; - */ - public int getResourceId() { - return resourceId_; - } - /** - *
-     * The id of the resource that this event occurred on. The full dataset should
-     * have this resource present in the Device object of the Trace object. A
-     * resource_id is unique on a specific device, but not necessarily within the
-     * trace.
-     * 
- * - * uint32 resource_id = 2; - */ - public Builder setResourceId(int value) { - - resourceId_ = value; - onChanged(); - return this; - } - /** - *
-     * The id of the resource that this event occurred on. The full dataset should
-     * have this resource present in the Device object of the Trace object. A
-     * resource_id is unique on a specific device, but not necessarily within the
-     * trace.
-     * 
- * - * uint32 resource_id = 2; - */ - public Builder clearResourceId() { - - resourceId_ = 0; - onChanged(); - return this; - } - - private java.lang.Object name_ = ""; - /** - *
-     * The name of this trace event.
-     * 
- * - * string name = 3; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - name_ = s; - return s; - } else { - return (java.lang.String) ref; - } - } - /** - *
-     * The name of this trace event.
-     * 
- * - * string name = 3; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - /** - *
-     * The name of this trace event.
-     * 
- * - * string name = 3; - */ - public Builder setName( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - - name_ = value; - onChanged(); - return this; - } - /** - *
-     * The name of this trace event.
-     * 
- * - * string name = 3; - */ - public Builder clearName() { - - name_ = getDefaultInstance().getName(); - onChanged(); - return this; - } - /** - *
-     * The name of this trace event.
-     * 
- * - * string name = 3; - */ - public Builder setNameBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); - - name_ = value; - onChanged(); - return this; - } - - private long timestampPs_ ; - /** - *
-     * The timestamp that this event occurred at (in picos since tracing started).
-     * 
- * - * uint64 timestamp_ps = 9; - */ - public long getTimestampPs() { - return timestampPs_; - } - /** - *
-     * The timestamp that this event occurred at (in picos since tracing started).
-     * 
- * - * uint64 timestamp_ps = 9; - */ - public Builder setTimestampPs(long value) { - - timestampPs_ = value; - onChanged(); - return this; - } - /** - *
-     * The timestamp that this event occurred at (in picos since tracing started).
-     * 
- * - * uint64 timestamp_ps = 9; - */ - public Builder clearTimestampPs() { - - timestampPs_ = 0L; - onChanged(); - return this; - } - - private long durationPs_ ; - /** - *
-     * The duration of the event in picoseconds if applicable.
-     * Events without duration are called instant events.
-     * 
- * - * uint64 duration_ps = 10; - */ - public long getDurationPs() { - return durationPs_; - } - /** - *
-     * The duration of the event in picoseconds if applicable.
-     * Events without duration are called instant events.
-     * 
- * - * uint64 duration_ps = 10; - */ - public Builder setDurationPs(long value) { - - durationPs_ = value; - onChanged(); - return this; - } - /** - *
-     * The duration of the event in picoseconds if applicable.
-     * Events without duration are called instant events.
-     * 
- * - * uint64 duration_ps = 10; - */ - public Builder clearDurationPs() { - - durationPs_ = 0L; - onChanged(); - return this; - } - - private com.google.protobuf.MapField< - java.lang.String, java.lang.String> args_; - private com.google.protobuf.MapField - internalGetArgs() { - if (args_ == null) { - return com.google.protobuf.MapField.emptyMapField( - ArgsDefaultEntryHolder.defaultEntry); - } - return args_; - } - private com.google.protobuf.MapField - internalGetMutableArgs() { - onChanged();; - if (args_ == null) { - args_ = com.google.protobuf.MapField.newMapField( - ArgsDefaultEntryHolder.defaultEntry); - } - if (!args_.isMutable()) { - args_ = args_.copy(); - } - return args_; - } - - public int getArgsCount() { - return internalGetArgs().getMap().size(); - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public boolean containsArgs( - java.lang.String key) { - if (key == null) { throw new java.lang.NullPointerException(); } - return internalGetArgs().getMap().containsKey(key); - } - /** - * Use {@link #getArgsMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getArgs() { - return getArgsMap(); - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public java.util.Map getArgsMap() { - return internalGetArgs().getMap(); - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public java.lang.String getArgsOrDefault( - java.lang.String key, - java.lang.String defaultValue) { - if (key == null) { throw new java.lang.NullPointerException(); } - java.util.Map map = - internalGetArgs().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public java.lang.String getArgsOrThrow( - java.lang.String key) { - if (key == null) { throw new java.lang.NullPointerException(); } - java.util.Map map = - internalGetArgs().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - public Builder clearArgs() { - internalGetMutableArgs().getMutableMap() - .clear(); - return this; - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public Builder removeArgs( - java.lang.String key) { - if (key == null) { throw new java.lang.NullPointerException(); } - internalGetMutableArgs().getMutableMap() - .remove(key); - return this; - } - /** - * Use alternate mutation accessors instead. - */ - @java.lang.Deprecated - public java.util.Map - getMutableArgs() { - return internalGetMutableArgs().getMutableMap(); - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - public Builder putArgs( - java.lang.String key, - java.lang.String value) { - if (key == null) { throw new java.lang.NullPointerException(); } - if (value == null) { throw new java.lang.NullPointerException(); } - internalGetMutableArgs().getMutableMap() - .put(key, value); - return this; - } - /** - *
-     * Extra arguments that will be displayed in trace view.
-     * 
- * - * map<string, string> args = 11; - */ - - public Builder putAllArgs( - java.util.Map values) { - internalGetMutableArgs().getMutableMap() - .putAll(values); - return this; - } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:tensorflow.profiler.TraceEvent) - } - - // @@protoc_insertion_point(class_scope:tensorflow.profiler.TraceEvent) - private static final org.tensorflow.proto.framework.TraceEvent DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new org.tensorflow.proto.framework.TraceEvent(); - } - - public static org.tensorflow.proto.framework.TraceEvent getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public TraceEvent parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new TraceEvent(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public org.tensorflow.proto.framework.TraceEvent getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - -} - diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventOrBuilder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventOrBuilder.java deleted file mode 100644 index 2bebd63d083..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventOrBuilder.java +++ /dev/null @@ -1,122 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -public interface TraceEventOrBuilder extends - // @@protoc_insertion_point(interface_extends:tensorflow.profiler.TraceEvent) - com.google.protobuf.MessageOrBuilder { - - /** - *
-   * The id of the device that this event occurred on. The full dataset should
-   * have this device present in the Trace object.
-   * 
- * - * uint32 device_id = 1; - */ - int getDeviceId(); - - /** - *
-   * The id of the resource that this event occurred on. The full dataset should
-   * have this resource present in the Device object of the Trace object. A
-   * resource_id is unique on a specific device, but not necessarily within the
-   * trace.
-   * 
- * - * uint32 resource_id = 2; - */ - int getResourceId(); - - /** - *
-   * The name of this trace event.
-   * 
- * - * string name = 3; - */ - java.lang.String getName(); - /** - *
-   * The name of this trace event.
-   * 
- * - * string name = 3; - */ - com.google.protobuf.ByteString - getNameBytes(); - - /** - *
-   * The timestamp that this event occurred at (in picos since tracing started).
-   * 
- * - * uint64 timestamp_ps = 9; - */ - long getTimestampPs(); - - /** - *
-   * The duration of the event in picoseconds if applicable.
-   * Events without duration are called instant events.
-   * 
- * - * uint64 duration_ps = 10; - */ - long getDurationPs(); - - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - int getArgsCount(); - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - boolean containsArgs( - java.lang.String key); - /** - * Use {@link #getArgsMap()} instead. - */ - @java.lang.Deprecated - java.util.Map - getArgs(); - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - java.util.Map - getArgsMap(); - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - java.lang.String getArgsOrDefault( - java.lang.String key, - java.lang.String defaultValue); - /** - *
-   * Extra arguments that will be displayed in trace view.
-   * 
- * - * map<string, string> args = 11; - */ - - java.lang.String getArgsOrThrow( - java.lang.String key); -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventsProtos.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventsProtos.java deleted file mode 100644 index 4e2e8c8d80a..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceEventsProtos.java +++ /dev/null @@ -1,133 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -public final class TraceEventsProtos { - private TraceEventsProtos() {} - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistryLite registry) { - } - - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistry registry) { - registerAllExtensions( - (com.google.protobuf.ExtensionRegistryLite) registry); - } - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_Trace_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_Trace_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_Trace_DevicesEntry_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_Trace_DevicesEntry_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_Device_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_Device_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_Device_ResourcesEntry_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_Device_ResourcesEntry_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_Resource_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_Resource_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_TraceEvent_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_TraceEvent_fieldAccessorTable; - static final com.google.protobuf.Descriptors.Descriptor - internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_descriptor; - static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_fieldAccessorTable; - - public static com.google.protobuf.Descriptors.FileDescriptor - getDescriptor() { - return descriptor; - } - private static com.google.protobuf.Descriptors.FileDescriptor - descriptor; - static { - java.lang.String[] descriptorData = { - "\n+tensorflow/core/protobuf/trace_events." + - "proto\022\023tensorflow.profiler\"\305\001\n\005Trace\0228\n\007" + - "devices\030\001 \003(\0132\'.tensorflow.profiler.Trac" + - "e.DevicesEntry\0225\n\014trace_events\030\004 \003(\0132\037.t" + - "ensorflow.profiler.TraceEvent\032K\n\014Devices" + - "Entry\022\013\n\003key\030\001 \001(\r\022*\n\005value\030\002 \001(\0132\033.tens" + - "orflow.profiler.Device:\0028\001\"\271\001\n\006Device\022\014\n" + - "\004name\030\001 \001(\t\022\021\n\tdevice_id\030\002 \001(\r\022=\n\tresour" + - "ces\030\003 \003(\0132*.tensorflow.profiler.Device.R" + - "esourcesEntry\032O\n\016ResourcesEntry\022\013\n\003key\030\001" + - " \001(\r\022,\n\005value\030\002 \001(\0132\035.tensorflow.profile" + - "r.Resource:\0028\001\"-\n\010Resource\022\014\n\004name\030\001 \001(\t" + - "\022\023\n\013resource_id\030\002 \001(\r\"\323\001\n\nTraceEvent\022\021\n\t" + - "device_id\030\001 \001(\r\022\023\n\013resource_id\030\002 \001(\r\022\014\n\004" + - "name\030\003 \001(\t\022\024\n\014timestamp_ps\030\t \001(\004\022\023\n\013dura" + - "tion_ps\030\n \001(\004\0227\n\004args\030\013 \003(\0132).tensorflow" + - ".profiler.TraceEvent.ArgsEntry\032+\n\tArgsEn" + - "try\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\0028\001B\202\001\n\036" + - "org.tensorflow.proto.frameworkB\021TraceEve" + - "ntsProtosP\001ZHgithub.com/tensorflow/tenso" + - "rflow/tensorflow/go/core/core_protos_go_" + - "proto\370\001\001b\006proto3" - }; - descriptor = com.google.protobuf.Descriptors.FileDescriptor - .internalBuildGeneratedFileFrom(descriptorData, - new com.google.protobuf.Descriptors.FileDescriptor[] { - }); - internal_static_tensorflow_profiler_Trace_descriptor = - getDescriptor().getMessageTypes().get(0); - internal_static_tensorflow_profiler_Trace_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_Trace_descriptor, - new java.lang.String[] { "Devices", "TraceEvents", }); - internal_static_tensorflow_profiler_Trace_DevicesEntry_descriptor = - internal_static_tensorflow_profiler_Trace_descriptor.getNestedTypes().get(0); - internal_static_tensorflow_profiler_Trace_DevicesEntry_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_Trace_DevicesEntry_descriptor, - new java.lang.String[] { "Key", "Value", }); - internal_static_tensorflow_profiler_Device_descriptor = - getDescriptor().getMessageTypes().get(1); - internal_static_tensorflow_profiler_Device_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_Device_descriptor, - new java.lang.String[] { "Name", "DeviceId", "Resources", }); - internal_static_tensorflow_profiler_Device_ResourcesEntry_descriptor = - internal_static_tensorflow_profiler_Device_descriptor.getNestedTypes().get(0); - internal_static_tensorflow_profiler_Device_ResourcesEntry_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_Device_ResourcesEntry_descriptor, - new java.lang.String[] { "Key", "Value", }); - internal_static_tensorflow_profiler_Resource_descriptor = - getDescriptor().getMessageTypes().get(2); - internal_static_tensorflow_profiler_Resource_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_Resource_descriptor, - new java.lang.String[] { "Name", "ResourceId", }); - internal_static_tensorflow_profiler_TraceEvent_descriptor = - getDescriptor().getMessageTypes().get(3); - internal_static_tensorflow_profiler_TraceEvent_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_TraceEvent_descriptor, - new java.lang.String[] { "DeviceId", "ResourceId", "Name", "TimestampPs", "DurationPs", "Args", }); - internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_descriptor = - internal_static_tensorflow_profiler_TraceEvent_descriptor.getNestedTypes().get(0); - internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_tensorflow_profiler_TraceEvent_ArgsEntry_descriptor, - new java.lang.String[] { "Key", "Value", }); - } - - // @@protoc_insertion_point(outer_class_scope) -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceOrBuilder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceOrBuilder.java deleted file mode 100644 index 78fa25b0aa9..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/proto/framework/TraceOrBuilder.java +++ /dev/null @@ -1,112 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensorflow/core/protobuf/trace_events.proto - -package org.tensorflow.proto.framework; - -public interface TraceOrBuilder extends - // @@protoc_insertion_point(interface_extends:tensorflow.profiler.Trace) - com.google.protobuf.MessageOrBuilder { - - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - int getDevicesCount(); - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - boolean containsDevices( - int key); - /** - * Use {@link #getDevicesMap()} instead. - */ - @java.lang.Deprecated - java.util.Map - getDevices(); - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - java.util.Map - getDevicesMap(); - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - org.tensorflow.proto.framework.Device getDevicesOrDefault( - int key, - org.tensorflow.proto.framework.Device defaultValue); - /** - *
-   * The devices that this trace has information about. Maps from device_id to
-   * more data about the specific device.
-   * 
- * - * map<uint32, .tensorflow.profiler.Device> devices = 1; - */ - - org.tensorflow.proto.framework.Device getDevicesOrThrow( - int key); - - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - java.util.List - getTraceEventsList(); - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - org.tensorflow.proto.framework.TraceEvent getTraceEvents(int index); - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - int getTraceEventsCount(); - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - java.util.List - getTraceEventsOrBuilderList(); - /** - *
-   * All trace events capturing in the profiling period.
-   * 
- * - * repeated .tensorflow.profiler.TraceEvent trace_events = 4; - */ - org.tensorflow.proto.framework.TraceEventOrBuilder getTraceEventsOrBuilder( - int index); -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 96da6bc5ff4..80aa928a09a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -17,6 +17,7 @@ import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -74,9 +75,9 @@ public String toString() { * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation. * * @param outputIdx index of the output of this operation - * @return output tensor datatype + * @return datatype native code */ - abstract DataType dtype(int outputIdx); + abstract DataType dtype(int outputIdx); /** * Returns the tensor of the {@code outputIdx}th output of this operation. @@ -86,5 +87,5 @@ public String toString() { * @param outputIdx index of the output of this operation * @return output tensor */ - abstract Tensor tensor(int outputIdx); + abstract TensorHandle tensor(int outputIdx); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractTypeFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractTypeFactory.java new file mode 100644 index 00000000000..a1846f3dede --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractTypeFactory.java @@ -0,0 +1,12 @@ +package org.tensorflow; + +import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.types.TypeFactory; +import org.tensorflow.types.family.TType; + +public abstract class AbstractTypeFactory implements TypeFactory { + + protected static TF_Tensor getNative(TensorHandle tensorHandle) { + return tensorHandle.get(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 872b4b4d16d..de516eba81f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -16,14 +16,13 @@ package org.tensorflow; import java.io.IOException; -import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; import java.util.Map; import java.util.function.Function; import org.tensorflow.op.Ops; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.family.TType; +import org.tensorflow.util.TensorMap; /** * A graph that can be invoked as a single function, with an input and output signature. @@ -163,14 +162,14 @@ public Signature signature() { * @return output tensors resulting from the execution of the function, * mapped by their signature name */ - public Map> call(Map> arguments) + public TensorMap call(Map arguments) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); final Session.Runner runner = session.runner(); signatureDef.getInputsMap().forEach((argName, t) -> { - Tensor tensor = arguments.get(argName); + TType tensor = arguments.get(argName); if (tensor == null) { throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); } @@ -180,24 +179,7 @@ public Map> call(Map> arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List> resultTensors = runner.run(); - try { - ListIterator> resultTensorIter = resultTensors.listIterator(); - Map> returnMap = new HashMap>(); - - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - return returnMap; - - } catch (Exception e) { - // Release tensors before throwing exception - for (Tensor t : resultTensors) { - t.close(); - } - throw e; - } + return runner.run().toMap(outputToNode.keySet()); } /** @@ -210,7 +192,7 @@ public Map> call(Map> arguments) * @throws IllegalArgumentException if there are multiple input or output parameters defined * in the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public T call(TType tensor) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { @@ -225,7 +207,7 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException { } String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); - return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + return (T)session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java deleted file mode 100644 index 7b76b6dd02e..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java +++ /dev/null @@ -1,181 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow; - -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat16; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TUint8; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** Represents a type of elements in a {@link Tensor} */ -public final class DataType { - - @FunctionalInterface - public interface TensorMapper { - - /** - * Maps tensor memory to a data structure for manipulating elements of this type. - * - * @param nativeTensor pointer to the native tensor - * @param shape the shape of the tensor - * @return data structure of elements of this type - */ - T apply(TF_Tensor nativeTensor, Shape shape); - } - - /** - * Creates a new datatype - * - * @param name readable-name for this type - * @param value must match the corresponding TF_* value in the TensorFlow C API. - * @param byteSize size of an element of this type, in bytes, -1 if unknown - * @param a tensor type - * @param tensorMapper method for mapping tensor memory to elements of this type - */ - public static DataType create( - String name, int value, int byteSize, TensorMapper tensorMapper) { - return new DataType<>(name, value, byteSize, tensorMapper); - } - - /** - * Gets the DataType associated with the readable-name for the type - *

The name must match exactly the name used to create the desired DataType

- * - * @param name readable-name for the type - * @return the DataType - * @throws java.lang.IllegalArgumentException if the name is not a valid data type name - * @throws java.lang.NullPointerException if name is null - */ - public static DataType of(String name) { - switch (name) { - case TBfloat16.NAME: - return TBfloat16.DTYPE; - case TFloat16.NAME: - return TFloat16.DTYPE; - case TFloat32.NAME: - return TFloat32.DTYPE; - case TFloat64.NAME: - return TFloat64.DTYPE; - case TUint8.NAME: - return TUint8.DTYPE; - case TInt32.NAME: - return TInt32.DTYPE; - case TInt64.NAME: - return TInt64.DTYPE; - case TBool.NAME: - return TBool.DTYPE; - case TString.NAME: - return TString.DTYPE; - default: - throw new IllegalArgumentException(String.format("%s is an unknown DataType", name)); - } - } - - /** Returns true if this data type represents a floating point type */ - public boolean isFloating() { - switch (this.name()) { - case TBfloat16.NAME: - case TFloat16.NAME: - case TFloat32.NAME: - case TFloat64.NAME: - return true; - default: - return false; - } - } - - /** Returns true if this data type represents an integer type */ - public boolean isInteger() { - switch (this.name()) { - case TInt32.NAME: - case TInt64.NAME: - case TUint8.NAME: - return true; - default: - return false; - } - } - - /** Returns true if this data type represents a numeric type */ - public boolean isNumeric() { - return isFloating() || isInteger(); - } - - /** Returns true if this data type represents a boolean type */ - public boolean isBoolean() { - return this.name().equals(TBool.NAME); - } - - /** Returns true if this data type represents a string type */ - public boolean isString() { - return this.name().equals(TString.NAME); - } - - /** Returns the size of an element of this type, in bytes, or -1 if element size is variable. */ - public int byteSize() { - return byteSize; - } - - /** Returns true if this datatype has elements of variable length */ - public boolean isVariableLength() { - return byteSize == -1; - } - - /** Returns a readable name for this type */ - public String name() { - return name; - } - - @Override - public String toString() { - return name + " (" + nativeCode + ")"; - } - - /** Returns the numeric code for this datatype, as recognized by the native library (C API) */ - int nativeCode() { - return nativeCode; - } - - /** - * Maps a tensor to a data structure for manipulating elements of this type. - * - * @param tensor tensor to map - * @return data structure of elements of this type - */ - T map(Tensor tensor) { - return tensorMapper.apply(tensor.nativeHandle(), tensor.shape()); - } - - private final int nativeCode; - private final int byteSize; - private final String name; - private final TensorMapper tensorMapper; - - private DataType(String name, int nativeCode, int byteSize, TensorMapper tensorMapper) { - this.name = name; - this.nativeCode = nativeCode; - this.byteSize = byteSize; - this.tensorMapper = tensorMapper; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java deleted file mode 100644 index 77c0de0c83f..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ======================================================================= - */ - -package org.tensorflow; - -import java.util.HashMap; -import java.util.Map; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat16; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; -import org.tensorflow.types.TUint8; - -/** - * Utility class for working with {@link DataType} objects. - */ -final class DataTypes { - - /** - * Find a data type from the type code returned by the native layer (C API). - * - *

Only data types registered via {@link #register(DataType)} can be resolved. - * - * @param nativeCode native code - * @return data type for this code - * @throws IllegalArgumentException if the code matches no registered data type - */ - static DataType fromNativeCode(int nativeCode) { - DataType dataType = DATA_TYPE_REGISTRY.get(nativeCode); - if (dataType == null) { - throw new IllegalArgumentException( - "DataType " + nativeCode + " is not recognized in Java (version " + TensorFlow.version() + ")"); - } - return dataType; - } - - private static final Map> DATA_TYPE_REGISTRY = new HashMap<>(); - - static { - register(TBool.DTYPE); - register(TFloat64.DTYPE); - register(TFloat32.DTYPE); - register(TFloat16.DTYPE); - register(TInt32.DTYPE); - register(TInt64.DTYPE); - register(TString.DTYPE); - register(TUint8.DTYPE); - register(TBfloat16.DTYPE); - } - - // TODO (karllessard): Right now this method is private but we might want to expose it - // to allow user to register custom data types? - private static void register(DataType dataType) { - DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType); - DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 012981ac59c..59da4d92dab 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -29,6 +29,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** * Implementation of an {@link Operation} executed eagerly. @@ -89,12 +90,6 @@ public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) { @Override public Shape shape(int outputIndex) { - // If the tensor of this output has already been resolved, return its shape. - // Otherwise, retrieve the tensor shape from the native library. - Tensor tensor = outputTensors.get(outputIndex); - if (tensor != null) { - return tensor.shape(); - } TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex); long[] shape = new long[numDims(outputNativeHandle)]; for (int i = 0; i < shape.length; ++i) { @@ -104,20 +99,14 @@ public Shape shape(int outputIndex) { } @Override - public DataType dtype(int outputIndex) { - // If the tensor of this output has already been resolved, return its datatype. - // Otherwise, retrieve the tensor datatype from the native library. - Tensor tensor = outputTensors.get(outputIndex); - if (tensor != null) { - return tensor.dataType(); - } + public DataType dtype(int outputIndex) { TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex); - return DataTypes.fromNativeCode(dataType(outputNativeHandle)); + return DataType.forNumber(dataType(outputNativeHandle)); } @Override - public Tensor tensor(int outputIndex) { - Tensor tensor = outputTensors.get(outputIndex); + public TensorHandle tensor(int outputIndex) { + TensorHandle tensor = outputTensors.get(outputIndex); if (tensor == null) { tensor = resolveTensor(outputIndex); } @@ -127,15 +116,15 @@ public Tensor tensor(int outputIndex) { private final EagerSession session; private final String type; private final String name; - private final AtomicReferenceArray> outputTensors; + private final AtomicReferenceArray outputTensors; - private Tensor resolveTensor(int outputIndex) { + private TensorHandle resolveTensor(int outputIndex) { // Take an optimistic approach, where we attempt to resolve the output tensor without locking. // If another thread has resolved it meanwhile, release our copy and reuse the existing one // instead. - Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); + TensorHandle tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { - session.detach(tensor.nativeHandle()); + session.detach(tensor.get()); tensor = outputTensors.get(outputIndex); } return tensor; @@ -156,13 +145,15 @@ private static void requireTensorHandle(TFE_TensorHandle handle) { } } - private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { + private static TensorHandle resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { requireTensorHandle(handle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); - TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); + TF_Tensor nativeTensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - return Tensor.fromHandle(tensor, session); + TensorHandle t = TensorHandle.of(nativeTensor); + t.attachTo(session); + return t; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index f14795df55a..c33d945c2d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -48,6 +48,8 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TypeRegistry; +import org.tensorflow.types.family.TType; /** * An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly. @@ -159,29 +161,29 @@ public EagerOperationBuilder setAttr(String name, boolean[] values) { } @Override - public EagerOperationBuilder setAttr(String name, DataType value) { - setAttrType(opHandle, name, value.nativeCode()); + public EagerOperationBuilder setAttr(String name, Class type) { + setAttrType(opHandle, name, TypeRegistry.find(type).dataType().getNumber()); return this; } @Override - public EagerOperationBuilder setAttr(String name, DataType[] values) { - int[] c = new int[values.length]; - for (int i = 0; i < values.length; ++i) { - c[i] = values[i].nativeCode(); + public EagerOperationBuilder setAttr(String name, Class[] types) { + int[] c = new int[types.length]; + for (int i = 0; i < types.length; ++i) { + c[i] = TypeRegistry.find(types[i]).dataType().getNumber(); } setAttrTypeList(opHandle, name, c); return this; } @Override - public EagerOperationBuilder setAttr(String name, Tensor value) { - setAttrTensor(opHandle, name, value.nativeHandle()); + public EagerOperationBuilder setAttr(String name, TType value) { + setAttrTensor(opHandle, name, value.handle().get()); return this; } @Override - public EagerOperationBuilder setAttr(String name, Tensor[] values) { + public EagerOperationBuilder setAttr(String name, TType[] values) { // TODO (karllessard) could be supported by adding this attribute type in the eager C API throw new UnsupportedOperationException( "Tensor list attributes are not supported in eager mode"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 142d481c04f..8a106983771 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -54,6 +54,7 @@ import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** @@ -766,15 +767,15 @@ private static SaverDef addVariableSaver(Graph graph) { Ops tf = Ops.create(graph).withSubScope("save"); List varNames = new ArrayList<>(); - List> varOutputs = new ArrayList<>(); - List> varTypes = new ArrayList<>(); + List> varOutputs = new ArrayList<>(); + List> varTypes = new ArrayList<>(); for (Iterator iter = graph.operations(); iter.hasNext();) { Operation op = iter.next(); if (op.type().equals("VariableV2")) { varNames.add(op.name()); varOutputs.add(op.output(0)); - varTypes.add(op.output(0).dataType()); + varTypes.add(op.output(0).type()); } } @@ -783,7 +784,7 @@ private static SaverDef addVariableSaver(Graph graph) { Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); Operand varSlices = tf.zerosLike(varNamesTensor); - Placeholder saveFilename = tf.placeholder(TString.DTYPE); + Placeholder saveFilename = tf.placeholder(TString.class); Save saveVariables = tf.train.save( saveFilename, varNamesTensor, @@ -798,7 +799,7 @@ private static SaverDef addVariableSaver(Graph graph) { ); List restoreOps = new ArrayList<>(varOutputs.size()); for (int i = 0; i < varOutputs.size(); ++i) { - restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i))); + restoreOps.add(tf.assign(varOutputs.get(i), (Operand)restoreVariables.tensors().get(i))); } NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index 70cd31366ce..36e57d5325c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -31,6 +31,8 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.family.TType; /** * Implementation for an {@link Operation} added as a node to a {@link Graph}. @@ -148,17 +150,17 @@ Shape shape(int outputIdx) { } @Override - DataType dtype(int outputIdx) { + DataType dtype(int outputIdx) { Graph.Reference r = graph.ref(); try { - return DataTypes.fromNativeCode(dtype(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx)); + return DataType.forNumber(dtype(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx)); } finally { r.close(); } } @Override - Tensor tensor(int outputIdx) { + TensorHandle tensor(int outputIdx) { throw new IllegalStateException("Graph tensors must be fetched by running a session"); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 2ef5c9010a1..a10ce0af010 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -52,6 +52,8 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TypeRegistry; +import org.tensorflow.types.family.TType; /** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ public final class GraphOperationBuilder implements OperationBuilder { @@ -221,10 +223,10 @@ public GraphOperationBuilder setAttr(String name, boolean[] value) { } @Override - public GraphOperationBuilder setAttr(String name, DataType value) { + public GraphOperationBuilder setAttr(String name, Class type) { Graph.Reference r = graph.ref(); try { - setAttrType(unsafeNativeHandle, name, value.nativeCode()); + setAttrType(unsafeNativeHandle, name, TypeRegistry.find(type).dataType().getNumber()); } finally { r.close(); } @@ -232,10 +234,10 @@ public GraphOperationBuilder setAttr(String name, DataType value) { } @Override - public GraphOperationBuilder setAttr(String name, DataType[] value) { - int[] ctypes = new int[value.length]; - for (int i = 0; i < value.length; ++i) { - ctypes[i] = value[i].nativeCode(); + public GraphOperationBuilder setAttr(String name, Class[] types) { + int[] ctypes = new int[types.length]; + for (int i = 0; i < types.length; ++i) { + ctypes[i] = TypeRegistry.find(types[i]).dataType().getNumber(); } Graph.Reference r = graph.ref(); try { @@ -247,10 +249,10 @@ public GraphOperationBuilder setAttr(String name, DataType[] value) { } @Override - public GraphOperationBuilder setAttr(String name, Tensor value) { + public GraphOperationBuilder setAttr(String name, TType value) { Graph.Reference r = graph.ref(); try { - setAttrTensor(unsafeNativeHandle, name, value.nativeHandle()); + setAttrTensor(unsafeNativeHandle, name, value.handle().get()); } finally { r.close(); } @@ -258,11 +260,11 @@ public GraphOperationBuilder setAttr(String name, Tensor value) { } @Override - public GraphOperationBuilder setAttr(String name, Tensor[] value) { + public GraphOperationBuilder setAttr(String name, TType[] value) { TF_Tensor[] handles = new TF_Tensor[value.length]; int idx = 0; - for (Tensor t : value) { - handles[idx++] = t.nativeHandle(); + for (TType t : value) { + handles[idx++] = t.handle().get(); } Graph.Reference r = graph.ref(); try { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index fa21f32d4ce..fcb630484d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -15,7 +15,10 @@ package org.tensorflow; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TypeRegistry; import org.tensorflow.types.family.TType; /** @@ -41,6 +44,17 @@ */ public interface Operand extends Op { + /** Returns the (possibly partially known) shape of the tensor referred to by this Output. */ + default Shape shape() { + return asOutput().shape(); + } + + /** Returns the DataType of the tensor referred to by this Output. */ + default Class type() { + return asOutput().type(); + } + + /** * Returns the symbolic handle of the tensor. * @@ -60,20 +74,15 @@ public interface Operand extends Op { * @return the tensor * @throws IllegalStateException if this is an operand of a graph */ - default Tensor asTensor() { + default T asTensor() { return asOutput().tensor(); } - /** - * Returns the data of this operand. - * - * Only works when running in an eager execution - *

This helper method is equivalent to {@code asTensor().data()} - * - * @return the tensor data - * @throws IllegalStateException if this is an operand of a graph - */ - default T data() { - return asOutput().tensor().data(); + default Operand expect(Class type) { + if (asOutput().type() != type) { + throw new IllegalArgumentException( + "Cannot cast from tensor of " + asOutput().type().getSimpleName() + " to tensor of " + type.getSimpleName()); + } + return (Operand)this; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index af1b8cc9130..df6a80a79ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -16,6 +16,7 @@ package org.tensorflow; import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; /** * A builder for {@link Operation}s. @@ -177,7 +178,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType value); + OperationBuilder setAttr(String name, Class value); /** * Set the type values of an attribute of the operation being built. @@ -186,7 +187,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType[] value); + OperationBuilder setAttr(String name, Class[] value); /** * Set the tensor value of an attribute of the operation being built. @@ -195,7 +196,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor value); + OperationBuilder setAttr(String name, TType value); /** * Set the tensor values of an attribute of the operation being built. @@ -204,7 +205,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor[] value); + OperationBuilder setAttr(String name, TType[] value); /** * Set the shape value of an attribute of the operation being built. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index a873df8ff4c..7397d676393 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -18,6 +18,9 @@ import java.util.Objects; import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.types.Type; +import org.tensorflow.types.TypeRegistry; import org.tensorflow.types.family.TType; /** @@ -37,29 +40,30 @@ public int index() { } /** Returns the (possibly partially known) shape of the tensor referred to by this Output. */ + @Override public Shape shape() { return operation.shape(index); } /** Returns the DataType of the tensor referred to by this Output. */ - @SuppressWarnings("unchecked") - public DataType dataType() { - return (DataType)operation.dtype(index); + @Override + public Class type() { + return ((Type)TypeRegistry.find(operation.dtype(index))).typeClass(); } /** * Returns this Output object with the type {@code Output}. This method is useful when given a * value of type {@code Output}. * - * @param dt any supported tensor data type + * @param tensorType type of tensor at this output * @throws IllegalArgumentException if the actual data type of this object does not match the type * {@code U}. */ @SuppressWarnings("unchecked") - public Output expect(DataType dt) { - if (!dt.equals(this.dataType())) { + public Output expect(Class tensorType) { + if (tensorType != type()) { throw new IllegalArgumentException( - "Cannot cast from output of " + this.dataType() + " to output of " + dt); + "Cannot cast from output of " + type().getSimpleName() + " to output of type " + tensorType.getSimpleName()); } return ((Output) this); } @@ -80,8 +84,8 @@ public Output expect(DataType dt) { * @see EagerSession */ @SuppressWarnings("unchecked") - public Tensor tensor() { - return (Tensor) operation.tensor(index); + public T tensor() { + return Tensors.fromHandle(operation.tensor(index)); } @Override @@ -115,7 +119,7 @@ public boolean equals(Object o) { public String toString() { return String.format( "<%s '%s:%d' shape=%s dtype=%s>", - operation.type(), operation.name(), index, shape().toString(), dataType()); + operation.type(), operation.name(), index, shape().toString(), type().getSimpleName()); } /** Handle to the idx-th output of the Operation {@code op}. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 093898ae56c..6a5ae0b6467 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -46,6 +46,8 @@ import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SavedModel; import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.family.TType; +import org.tensorflow.util.TensorMap; /** * SavedModelBundle represents a model loaded from storage. @@ -334,7 +336,7 @@ public ConcreteFunction function(String signatureKey) { * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map> call(Map> arguments) { + public TensorMap call(Map arguments) { ConcreteFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 4e82f3944b8..07e70db6dc9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,15 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,15 +41,11 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.family.TType; +import org.tensorflow.util.TensorList; import org.tensorflow.types.TString; -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - /** * Driver for {@link Graph} execution. * @@ -158,7 +162,7 @@ public final class Runner { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(String operation, Tensor t) { + public Runner feed(String operation, TType t) { return feed(parseOutput(operation), t); } @@ -173,11 +177,11 @@ public Runner feed(String operation, Tensor t) { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(String operation, int index, Tensor t) { + public Runner feed(String operation, int index, TType t) { Operation op = operationByName(operation); if (op != null) { inputs.add(op.output(index)); - inputTensors.add(t); + inputTensors.add(t.handle()); } return this; } @@ -190,9 +194,9 @@ public Runner feed(String operation, int index, Tensor t) { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(Operand operand, Tensor t) { + public Runner feed(Operand operand, TType t) { inputs.add(operand.asOutput()); - inputTensors.add(t); + inputTensors.add(t.handle()); return this; } @@ -325,7 +329,7 @@ public Runner setOptions(RunOptions options) { * * @return list of resulting tensors fetched by this session runner */ - public List> run() { + public TensorList run() { return runHelper(false).outputs; } @@ -354,8 +358,8 @@ private Run runHelper(boolean wantMetadata) { // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. int idx = 0; - for (Tensor t : inputTensors) { - inputTensorHandles[idx++] = t.nativeHandle(); + for (TensorHandle t : inputTensors) { + inputTensorHandles[idx++] = t.get(); } idx = 0; for (Output o : inputs) { @@ -375,7 +379,7 @@ private Run runHelper(boolean wantMetadata) { } Reference runRef = new Reference(); RunMetadata metadata = null; - List> outputs = new ArrayList<>(); + TensorList outputs = new TensorList(); try { metadata = Session.run( @@ -390,10 +394,7 @@ private Run runHelper(boolean wantMetadata) { wantMetadata, outputs); } catch (Exception e) { - for (Tensor t : outputs) { - t.close(); - } - outputs.clear(); + outputs.close(); throw e; } finally { runRef.close(); @@ -451,7 +452,7 @@ private Output parseOutput(String opName) { } private ArrayList> inputs = new ArrayList<>(); - private ArrayList> inputTensors = new ArrayList<>(); + private ArrayList inputTensors = new ArrayList<>(); private ArrayList> outputs = new ArrayList<>(); private ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; @@ -518,7 +519,7 @@ public void save(String prefix) { */ public static final class Run { /** Tensors from requested fetches. */ - public List> outputs; + public TensorList outputs; /** * Metadata about the run. @@ -627,7 +628,7 @@ private static RunMetadata run( int[] outputOpIndices, TF_Operation[] targetOpHandles, boolean wantRunMetadata, - List> outputTensors) { + TensorList outputTensors) { requireHandle(handle); int ninputs = inputTensorHandles.length; @@ -667,7 +668,7 @@ private static RunMetadata run( for (int i = 0; i < noutputs; ++i) { TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - outputTensors.add(Tensor.fromHandle(h)); + outputTensors.add(Tensors.fromHandle(TensorHandle.of(h))); } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 376dc9039fc..75785ece8ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -18,11 +18,11 @@ import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; -import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; import org.tensorflow.proto.framework.TensorShapeProto.Dim; +import org.tensorflow.types.TypeRegistry; /** * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among @@ -113,7 +113,7 @@ private static TensorInfo toTensorInfo(Output operand) { tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); } return TensorInfo.newBuilder() - .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setDtype(TypeRegistry.find(operand.type()).dataType()) .setTensorShape(tensorShapeBuilder) .setName(operand.op().name() + ":" + operand.index()) .build(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 6787713418f..4f47bd06dd2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -15,19 +15,9 @@ package org.tensorflow; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType; - -import java.util.function.Consumer; -import org.bytedeco.javacpp.PointerScope; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.NdArrayBase; import org.tensorflow.ndarray.buffer.ByteDataBuffer; -import org.tensorflow.types.family.TType; +import org.tensorflow.proto.framework.DataType; /** * A statically typed multi-dimensional array whose elements are of a type described by T. @@ -44,167 +34,12 @@ * } * } */ -public final class Tensor implements AutoCloseable { - - /** - * Allocates a tensor of a given datatype and shape. - * - *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. - * Memory is left uninitialized after this method returns, so it is the responsibility of the - * caller to initialize the tensor data before it is used, via the {@link #data()} accessor. - * For example: - * - *

{@code
-   * FloatNdArray data = ...
-   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2))) {
-   *   data.copyTo(t.data());
-   *   ...
-   * }
-   * }
- * - * @param the tensor element type - * @param dtype datatype of the tensor - * @param shape shape of the tensor - * @return an allocated but uninitialized tensor - * @throws IllegalStateException if tensor failed to be allocated - */ - public static Tensor of(DataType dtype, Shape shape) { - return of(dtype, shape, shape.size() * dtype.byteSize()); - } - - /** - * Allocates a tensor of a given datatype, shape and size. - * - *

This method is identical to {@link #of(DataType, Shape)}, except that the final size of the - * tensor is explicitly set instead of computing it from the datatype and shape. - * - *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * like {@link org.tensorflow.types.TString TString}. - * - * @param the tensor element type - * @param dtype datatype of the tensor - * @param shape shape of the tensor - * @param size size, in bytes, of the tensor - * @return an allocated but uninitialized tensor - * @see #of(DataType, Shape) - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalStateException if tensor failed to be allocated - */ - public static Tensor of(DataType dtype, Shape shape, long size) { - // Minimum requirements for datatypes of variable length cannot be verified in a relevant way so - // we only validate them for fixed length datatypes - if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) { - throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values"); - } - Tensor t = new Tensor<>(dtype, shape); - TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size); - try (PointerScope scope = new PointerScope()) { - scope.attach(nativeHandle); - t.tensorHandle = nativeHandle; - t.tensorScope = scope.extend(); - return t; - } - } - - /** - * Allocates and initialize a tensor of a given datatype and shape. - * - *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. - * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument - * the value returned by {@link #data()} on the allocated tensor. For example: - * - *

{@code
-   * FloatNdArray data = ...
-   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2), data::copyTo)) {
-   *   ...
-   * }
-   * }
- * - *

If {@code dataInitializer} fails and throws an exception, the allocated tensor will be - * automatically released before rethrowing the same exception. - * - * @param the tensor element type - * @param dtype datatype of the tensor - * @param shape shape of the tensor - * @param dataInitializer method receiving accessor to the allocated tensor data for initialization - * @return an allocated and initialized tensor - * @throws IllegalStateException if tensor failed to be allocated - */ - public static Tensor of(DataType dtype, Shape shape, - Consumer dataInitializer) { - return of(dtype, shape, shape.size() * dtype.byteSize(), dataInitializer); - } +public interface Tensor extends NdArrayBase, AutoCloseable { /** - * Allocates a tensor of a given datatype, shape and size. - * - *

This method is identical to {@link #of(DataType, Shape, Consumer)}, except that the final - * size for the tensor is explicitly set instead of being computed from the datatype and shape. - * - *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * such as {@link org.tensorflow.types.TString TString}. - * - * @param the tensor element type - * @param dtype datatype of the tensor - * @param shape shape of the tensor - * @param size size, in bytes, of the tensor - * @param dataInitializer method receiving accessor to the allocated tensor data for initialization - * @return an allocated and initialized tensor - * @see #of(DataType, Shape, long, Consumer) - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalStateException if tensor failed to be allocated + * Return the handle to the native tensor */ - public static Tensor of(DataType dtype, Shape shape, long size, - Consumer dataInitializer) { - Tensor tensor = of(dtype, shape, size); - try { - dataInitializer.accept(tensor.data()); - return tensor; - } catch (Throwable t) { - tensor.close(); - throw t; - } - } - - /** - * Creates a Tensor of any type from the raw data provided by the given buffer. - * - *

Data must have been encoded into {@code data} as per the specification of the TensorFlow C API. - * - * @param the tensor element type - * @param dtype the tensor element data type - * @param shape the tensor shape. - * @param rawData a buffer containing the tensor raw data. - * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data - * @throws IllegalStateException if tensor failed to be allocated with the given parameters - */ - public static Tensor of(DataType dtype, Shape shape, ByteDataBuffer rawData) { - Tensor t = of(dtype, shape, rawData.size()); - rawData.copyTo(TensorBuffers.toBytes(t.nativeHandle()), rawData.size()); - return t; - } - - /** - * Returns this Tensor object with the type {@code Tensor}. This method is useful when given a - * value of type {@code Tensor}. - * - * @param dt any supported tensor data type - * @param a tensor type - * @return a tensor of the requested data type - * @throws IllegalArgumentException if the actual data type of this object does not match the type - * {@code U}. - */ - @SuppressWarnings("unchecked") - public Tensor expect(DataType dt) { - if (!dt.equals(this.dtype)) { - throw new IllegalArgumentException( - "Cannot cast from tensor of " + dtype + " to tensor of " + dt); - } - return ((Tensor) this); - } + TensorHandle handle(); /** * Release resources associated with the Tensor. @@ -215,73 +50,7 @@ public Tensor expect(DataType dt) { *

The Tensor object is no longer usable after {@code close} returns. */ @Override - public void close() { - tensorScope.close(); - } - - /** Returns the {@link DataType} of elements stored in the Tensor. */ - public DataType dataType() { - return dtype; - } - - /** Returns the size, in bytes, of the tensor data. */ - public long numBytes() { - if (numBytes == null) { - numBytes = TF_TensorByteSize(tensorHandle); - } - return numBytes; - } - - /** - * Returns the shape of - * the Tensor, i.e., the sizes of each dimension. - * - * @return shape of this tensor - */ - public Shape shape() { - return shape; - } - - /** - * Returns the data of this tensor. - * - *

This method returns an accessor to the tensor data as an instance of {@code T}, which - * commonly maps this data to an {@link NdArray NdArray}. Input and - * output operations performed on the returned n-dimensional array are applied directly to the - * tensor native memory. For example: - * - *

{@code
-   * Ops tf = Ops.create();
-   * try (Tensor t = TFloat32.tensorOf(Shape.of(2, 2))) {
-   *   TFloat32 data = t.data();
-   *
-   *   StdArrays.copyTo(data, new float[][] {
-   *     {1.0f, 2.0f},
-   *     {3.0f, 4.0f}
-   *   });
-   *   assertEquals(NdArrays.vectorOf(3.0f, 4.0f), data.getFloat(1));
-   *
-   *   Constant c = tf.constant(t);
-   *   assertEquals(4.0f, c.data().getFloat(1, 1));
-   * }
-   * }
- * - *

Please refer to the documentation of the {@link NdArray NdArray} - * classes for more information on the various techniques to read or write data in an - * n-dimensional space using this data structure. - * - * @return the tensor data mapped to an n-dimensional space - * @throws IllegalStateException if the tensor has been closed - * @see NdArray - */ - public T data() { - if (data == null) { - data = dtype.map(this); - } else { - nativeHandle(); // Checks that the tensor has not been released or will throw - } - return data; - } + void close(); /** * Returns the raw data of this tensor as a buffer of bytes. @@ -293,95 +62,13 @@ public T data() { * @return the tensor raw data mapped to a read-only byte buffer * @throws IllegalStateException if the tensor has been closed */ - public ByteDataBuffer rawData() { - return TensorBuffers.toBytes(nativeHandle(), true); - } - - /** Returns a string describing the type and shape of the Tensor. */ - @Override - public String toString() { - return String.format("%s tensor with shape %s", dtype.toString(), shape); - } - - /** - * Create a Tensor object from a handle to the C TF_Tensor object. - * - *

Takes ownership of the handle. - */ - static Tensor fromHandle(TF_Tensor handle) { - Tensor t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle))); - try (PointerScope scope = new PointerScope()) { - scope.attach(handle); - t.tensorHandle = handle; - t.tensorScope = scope.extend(); - } - return t; - } + ByteDataBuffer rawData(); - /** - * Create an eager Tensor object from a handle to the C TF_Tensor object. - * - *

Takes ownership of the handle. - */ - static Tensor fromHandle(TF_Tensor handle, EagerSession session) { - Tensor t = fromHandle(handle); - session.attach(handle); - t.tensorScope.detach(handle); - return t; - } - - /** - * @return native handle to this tensor - * @throws IllegalStateException if tensor has been closed - */ - TF_Tensor nativeHandle() { - return requireHandle(tensorHandle); - } - - private PointerScope tensorScope; - private TF_Tensor tensorHandle; - - private static TF_Tensor requireHandle(TF_Tensor handle) { - if (handle == null || handle.isNull()) { - throw new IllegalStateException("close() was called on the Tensor"); - } - return handle; - } - - private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) { - TF_Tensor t = TF_Tensor.allocateTensor(dtype, shape, byteSize); - if (t == null || t.isNull()) { - throw new IllegalStateException("unable to allocate memory for the Tensor"); - } - return t; - } - - private static int dtype(TF_Tensor handle) { - requireHandle(handle); - return TF_TensorType(handle); - } + /** Returns the size, in bytes, of the tensor data. */ + long numBytes(); - private static long[] shape(TF_Tensor handle) { - requireHandle(handle); - int numDims = TF_NumDims(handle); - long[] dims = new long[numDims]; - for (int i = 0; i < numDims; ++i) { - dims[i] = TF_Dim(handle, i); - } - return dims; - } + DataType dataType(); +} - private final DataType dtype; - private final Shape shape; - private T data = null; - private Long numBytes = null; - private Tensor(DataType dtype, Shape shape) { - this.dtype = dtype; - this.shape = shape; - } - static { - TensorFlow.init(); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorHandle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorHandle.java new file mode 100644 index 00000000000..7ac9ac291a8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorHandle.java @@ -0,0 +1,132 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorElementCount; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType; + +import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; + +public final class TensorHandle implements Tensor { + + @Override + public TensorHandle handle() { + return this; + } + + public void retain() { + requireHandle(nativeHandle).retainReference(); + } + + public void release() { + requireHandle(nativeHandle).releaseReference(); + } + + @Override + public Shape shape() { + return Shape.of(shape(nativeHandle)); + } + + @Override + public int rank() { + return numDims(nativeHandle); + } + + @Override + public long size() { + return numElements(nativeHandle); + } + + @Override + public void close() { + get().close(); + this.nativeHandle = null; + } + + @Override + public long numBytes() { + return TF_TensorByteSize(get()); + } + + @Override + public ByteDataBuffer rawData() { + return TensorBuffers.toBytes(get(), true); + } + + @Override + public DataType dataType() { + return DataType.forNumber(dtype(get())); + } + + static TensorHandle of(TF_Tensor nativeHandle) { + return new TensorHandle(nativeHandle); + } + + void attachTo(EagerSession session) { + session.attach(get()); + nativeHandle.releaseReference(); + } + + TF_Tensor get() { + return requireHandle(nativeHandle); + } + + private TensorHandle(TF_Tensor nativeHandle) { + this.nativeHandle = nativeHandle; + nativeHandle.retainReference(); + } + + private static TF_Tensor requireHandle(TF_Tensor handle) { + if (handle == null || handle.isNull()) { + throw new IllegalStateException("close() was called on the Tensor"); + } + return handle; + } + + private static int dtype(TF_Tensor handle) { + requireHandle(handle); + return TF_TensorType(handle); + } + + private static int numDims(TF_Tensor handle) { + requireHandle(handle); + return TF_NumDims(handle); + } + + private static long numElements(TF_Tensor handle) { + requireHandle(handle); + return TF_TensorElementCount(handle); + } + + private static long[] shape(TF_Tensor handle) { + requireHandle(handle); + int numDims = TF_NumDims(handle); + long[] dims = new long[numDims]; + for (int i = 0; i < numDims; ++i) { + dims[i] = TF_Dim(handle, i); + } + return dims; + } + + private TF_Tensor nativeHandle; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensors.java new file mode 100644 index 00000000000..22c6a769e0f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensors.java @@ -0,0 +1,201 @@ +package org.tensorflow; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims; + +import java.util.function.Consumer; +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.Type; +import org.tensorflow.types.TypeRegistry; +import org.tensorflow.types.family.TType; + +public final class Tensors { + + /** + * Allocates a tensor of a given datatype and shape. + * + *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. + * Memory is left uninitialized after this method returns, so it is the responsibility of the + * caller to initialize the tensor data before it is used, via the {@link #data()} accessor. + * For example: + * + *

{@code
+   * FloatNdArray data = ...
+   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2))) {
+   *   data.copyTo(t.data());
+   *   ...
+   * }
+   * }
+ * + * @param the tensor element type + * @param type tensor type + * @param shape shape of the tensor + * @return an allocated but uninitialized tensor + * @throws IllegalStateException if tensor failed to be allocated + */ + public static T of(Class type, Shape shape) { + return of(type, shape, -1); + } + + /** + * Allocates a tensor of a given datatype, shape and size. + * + *

This method is identical to {@link #of(DataType, Shape)}, except that the final size of the + * tensor is explicitly set instead of computing it from the datatype and shape. + * + *

This could be useful for tensor types that stores data but also metadata in the tensor memory, + * like {@link org.tensorflow.types.TString TString}. + * + * @param the tensor element type + * @param type tensor type + * @param shape shape of the tensor + * @param size size, in bytes, of the tensor + * @return an allocated but uninitialized tensor + * @see #of(DataType, Shape) + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to + * store the tensor data + * @throws IllegalStateException if tensor failed to be allocated + */ + public static T of(Class type, Shape shape, long size) { + return allocate(type, shape, size); + } + + /** + * Allocates and initialize a tensor of a given datatype and shape. + * + *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. + * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument + * the value returned by {@link #data()} on the allocated tensor. For example: + * + *

{@code
+   * FloatNdArray data = ...
+   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2), data::copyTo)) {
+   *   ...
+   * }
+   * }
+ * + *

If {@code dataInitializer} fails and throws an exception, the allocated tensor will be + * automatically released before rethrowing the same exception. + * + * @param the tensor element type + * @param type tensor type + * @param shape shape of the tensor + * @param dataInitializer method receiving accessor to the allocated tensor data for initialization + * @return an allocated and initialized tensor + * @throws IllegalStateException if tensor failed to be allocated + */ + public static T of(Class type, Shape shape, Consumer dataInitializer) { + return of(type, shape, -1, dataInitializer); + } + + /** + * Allocates a tensor of a given datatype, shape and size. + * + *

This method is identical to {@link #of(DataType, Shape, Consumer)}, except that the final + * size for the tensor is explicitly set instead of being computed from the datatype and shape. + * + *

This could be useful for tensor types that stores data but also metadata in the tensor memory, + * such as {@link org.tensorflow.types.TString TString}. + * + * @param the tensor element type + * @param type tensor type + * @param shape shape of the tensor + * @param size size, in bytes, of the tensor + * @param dataInitializer method receiving accessor to the allocated tensor data for initialization + * @return an allocated and initialized tensor + * @see #of(DataType, Shape, long, Consumer) + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to + * store the tensor data + * @throws IllegalStateException if tensor failed to be allocated + */ + public static T of(Class type, Shape shape, long size, Consumer dataInitializer) { + T tensor = of(type, shape, size); + try { + dataInitializer.accept(tensor); + return tensor; + } catch (Throwable t) { + tensor.close(); + throw t; + } + } + + /** + * Creates a Tensor of any type from the raw data provided by the given buffer. + * + *

Data must have been encoded into {@code data} as per the specification of the TensorFlow C API. + * + * @param the tensor element type + * @param type tensor type + * @param dtype the tensor element data type + * @param shape the tensor shape. + * @param rawData a buffer containing the tensor raw data. + * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data + * @throws IllegalStateException if tensor failed to be allocated with the given parameters + */ + public static T of(Class type, Shape shape, ByteDataBuffer rawData) { + T t = of(type, shape, rawData.size()); + rawData.copyTo(TensorBuffers.toBytes(t.handle().get()), rawData.size()); + return t; + } + + /** + * Create a Tensor object from a handle to the C TF_Tensor object. + * + *

Takes ownership of the handle. + */ + static T fromHandle(TensorHandle handle) { + Type type = TypeRegistry.find(handle.dataType()); + Shape shape = Shape.of(shape(handle.get())); + return type.factory().createDense(handle, shape); + } + + private static T allocate(Class typeClass, Shape shape, long size) { + Type type = TypeRegistry.find(typeClass); + long effectiveSize = size; + if (effectiveSize < 0) { + // Size of the tensor is by default the sum of the size of all its element + effectiveSize = shape.size() * type.byteSize(); + + } else if (!type.isVariableLength() && shape.size() * type.byteSize() > effectiveSize) { + // Minimum requirements for datatypes of variable length cannot be verified in a relevant way + // so we only validate them for fixed length datatypes + throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values"); + } + TF_Tensor nativeHandle = allocate(type.dataType().getNumber(), shape.asArray(), effectiveSize); + try (PointerScope scope = new PointerScope()) { + scope.attach(nativeHandle); + return type.factory().createDense(TensorHandle.of(nativeHandle), shape); + } + } + + private static TF_Tensor requireHandle(TF_Tensor handle) { + if (handle == null || handle.isNull()) { + throw new IllegalStateException("close() was called on the Tensor"); + } + return handle; + } + + private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) { + TF_Tensor t = TF_Tensor.allocateTensor(dtype, shape, byteSize); + if (t == null || t.isNull()) { + throw new IllegalStateException("unable to allocate memory for the Tensor"); + } + return t; + } + + private static long[] shape(TF_Tensor handle) { + requireHandle(handle); + int numDims = TF_NumDims(handle); + long[] dims = new long[numDims]; + for (int i = 0; i < numDims; ++i) { + dims[i] = TF_Dim(handle, i); + } + return dims; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java similarity index 91% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java index 83cdab33452..bc80b2a164c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java @@ -50,7 +50,7 @@ *

After its data has been initialized, the buffer is read-only as it is not possible to change * safely a value without reinitializing the whole data. */ -public class StringTensorBuffer extends AbstractDataBuffer { +public class ByteSequenceTensorBuffer extends AbstractDataBuffer { /** * Computes how many bytes are required to store the given data in a string buffer. @@ -66,7 +66,7 @@ public static long computeSize(NdArray data, Function getBytes // reserve space to store length and data of each values for (NdArray scalar : data.scalars()) { byte[] elementBytes = getBytes.apply(scalar.getObject()); - size += elementBytes.length + StringTensorBuffer.varintLength(elementBytes.length); + size += elementBytes.length + ByteSequenceTensorBuffer.varintLength(elementBytes.length); } return size; } @@ -129,8 +129,8 @@ public boolean isReadOnly() { @Override public DataBuffer copyTo(DataBuffer dst, long size) { - if (size == size() && dst instanceof StringTensorBuffer) { - StringTensorBuffer tensorDst = (StringTensorBuffer) dst; + if (size == size() && dst instanceof ByteSequenceTensorBuffer) { + ByteSequenceTensorBuffer tensorDst = (ByteSequenceTensorBuffer) dst; if (offsets.size() != size || data.size() != size) { throw new IllegalArgumentException( "Cannot copy string tensor data to another tensor of a different size"); @@ -145,20 +145,20 @@ public DataBuffer copyTo(DataBuffer dst, long size) { @Override public DataBuffer offset(long index) { - return new StringTensorBuffer(offsets.offset(index), data); + return new ByteSequenceTensorBuffer(offsets.offset(index), data); } @Override public DataBuffer narrow(long size) { - return new StringTensorBuffer(offsets.narrow(size), data); + return new ByteSequenceTensorBuffer(offsets.narrow(size), data); } @Override public DataBuffer slice(long index, long size) { - return new StringTensorBuffer(offsets.slice(index, size), data); + return new ByteSequenceTensorBuffer(offsets.slice(index, size), data); } - StringTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) { + ByteSequenceTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) { this.offsets = offsets; this.data = data; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java index f29396dd321..415c5ca35ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java @@ -156,7 +156,7 @@ public static BooleanDataBuffer toBooleans(TF_Tensor nativeTensor) { * @param nativeTensor native reference to the tensor * @return a string buffer */ - public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) { + public static ByteSequenceTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) { Pointer tensorMemory = tensorMemory(nativeTensor); if (TensorRawDataBufferFactory.canBeUsed()) { return TensorRawDataBufferFactory.mapTensorToStrings(tensorMemory, numElements); @@ -173,7 +173,7 @@ public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numEleme dataBuffer.position((int)numElements * Long.BYTES); ByteDataBuffer data = DataBuffers.of(dataBuffer.slice()); - return new StringTensorBuffer(offsets, data); + return new ByteSequenceTensorBuffer(offsets, data); } private static Pointer tensorMemory(TF_Tensor nativeTensor) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java index 1cfb1c9ab9a..dbaf31f1dcc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java @@ -57,13 +57,13 @@ static BooleanDataBuffer mapTensorToBooleans(Pointer tensorMemory) { return mapNativeBooleans(tensorMemory.address(), tensorMemory.capacity(), false); } - static StringTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) { + static ByteSequenceTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) { long offsetByteSize = numElements * Long.BYTES; LongDataBuffer offsets = mapNativeLongs(tensorMemory.address(), offsetByteSize, false); ByteDataBuffer data = mapNativeBytes( tensorMemory.address() + offsetByteSize, tensorMemory.capacity() - offsetByteSize, false); - return new StringTensorBuffer(offsets, data); + return new ByteSequenceTensorBuffer(offsets, data); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Factory.java new file mode 100644 index 00000000000..50542dc0dc7 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Factory.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TBfloat16; + +/** + * Factory of {@link TBfloat16} tensor instances + */ +public class TBfloat16Factory extends AbstractTypeFactory { + + @Override + public TBfloat16 createDense(TensorHandle tensorHandle, Shape shape) { + FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(getNative(tensorHandle))); + return new TBfloat16Impl(tensorHandle, buffer, shape); + } + + private static final class TBfloat16Impl extends FloatDenseNdArray implements TBfloat16 { + + @Override + public Class type() { + return TBfloat16.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TBfloat16Impl(TensorHandle tensorHandle, FloatDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolFactory.java new file mode 100644 index 00000000000..1d1552cc291 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolFactory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; +import org.tensorflow.types.TBool; + +/** + * Factory of {@link TBool} tensor instances + */ +public class TBoolFactory extends AbstractTypeFactory { + + @Override + public TBool createDense(TensorHandle tensorHandle, Shape shape) { + BooleanDataBuffer buffer = TensorBuffers.toBooleans(getNative(tensorHandle)); + return new TBoolImpl(tensorHandle, buffer, shape); + } + + private static final class TBoolImpl extends BooleanDenseNdArray implements TBool { + + @Override + public Class type() { + return TBool.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TBoolImpl(TensorHandle tensorHandle, BooleanDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Factory.java new file mode 100644 index 00000000000..73301f6894e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Factory.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TFloat16; + +/** + * Factory of {@link TFloat16} tensor instances + */ +public class TFloat16Factory extends AbstractTypeFactory { + + @Override + public TFloat16 createDense(TensorHandle tensorHandle, Shape shape) { + FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(getNative(tensorHandle))); + return new TFloat16Impl(tensorHandle, buffer, shape); + } + + private static final class TFloat16Impl extends FloatDenseNdArray implements TFloat16 { + + @Override + public Class type() { + return TFloat16.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TFloat16Impl(TensorHandle tensorHandle, FloatDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Factory.java new file mode 100644 index 00000000000..ae3a53f2c8a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Factory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TFloat32; + +/** + * Factory of {@link TFloat32} tensor instances + */ +public class TFloat32Factory extends AbstractTypeFactory { + + @Override + public TFloat32 createDense(TensorHandle tensorHandle, Shape shape) { + FloatDataBuffer buffer = TensorBuffers.toFloats(getNative(tensorHandle)); + return new TFloat32Impl(tensorHandle, buffer, shape); + } + + private static final class TFloat32Impl extends FloatDenseNdArray implements TFloat32 { + + @Override + public Class type() { + return TFloat32.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TFloat32Impl(TensorHandle tensorHandle, FloatDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Factory.java new file mode 100644 index 00000000000..495f743e8cc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Factory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.types.TFloat64; + +/** + * Factory of {@link TFloat64} tensor instances + */ +public class TFloat64Factory extends AbstractTypeFactory { + + @Override + public TFloat64 createDense(TensorHandle tensorHandle, Shape shape) { + DoubleDataBuffer buffer = TensorBuffers.toDoubles(getNative(tensorHandle)); + return new TFloat64Impl(tensorHandle, buffer, shape); + } + + private static final class TFloat64Impl extends DoubleDenseNdArray implements TFloat64 { + + @Override + public Class type() { + return TFloat64.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TFloat64Impl(TensorHandle tensorHandle, DoubleDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Factory.java new file mode 100644 index 00000000000..368f543d743 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Factory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; +import org.tensorflow.types.TInt32; + +/** + * Factory of {@link TInt32} tensor instances + */ +public class TInt32Factory extends AbstractTypeFactory { + + @Override + public TInt32 createDense(TensorHandle tensorHandle, Shape shape) { + IntDataBuffer buffer = TensorBuffers.toInts(getNative(tensorHandle)); + return new TInt32Impl(tensorHandle, buffer, shape); + } + + private static final class TInt32Impl extends IntDenseNdArray implements TInt32 { + + @Override + public Class type() { + return TInt32.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TInt32Impl(TensorHandle tensorHandle, IntDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Factory.java new file mode 100644 index 00000000000..6a6fc0317ec --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Factory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; +import org.tensorflow.types.TInt64; + +/** + * Factory of {@link TInt64} tensor instances + */ +public class TInt64Factory extends AbstractTypeFactory { + + @Override + public TInt64 createDense(TensorHandle tensorHandle, Shape shape) { + LongDataBuffer buffer = TensorBuffers.toLongs(getNative(tensorHandle)); + return new Int64DenseTensor(tensorHandle, buffer, shape); + } + + private static final class Int64DenseTensor extends LongDenseNdArray implements TInt64 { + + @Override + public Class type() { + return TInt64.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + Int64DenseTensor(TensorHandle tensorHandle, LongDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringFactory.java new file mode 100644 index 00000000000..562662579fe --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringFactory.java @@ -0,0 +1,78 @@ +package org.tensorflow.internal.types; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayout; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.types.TString; + +/** + * Factory of {@link TString} tensor instances + */ +public class TStringFactory extends AbstractTypeFactory { + + @Override + public TString createDense(TensorHandle tensorHandle, Shape shape) { + ByteSequenceTensorBuffer buffer = TensorBuffers.toStrings(getNative(tensorHandle), shape.size()); + return new TStringImpl(tensorHandle, buffer, shape); + } + + private static final DataLayout, String> UTF_8_LAYOUT = + DataLayouts.ofStrings(StandardCharsets.UTF_8); + + private static final class TStringImpl extends DenseNdArray implements TString { + + @Override + public NdArray asBytes() { + return NdArrays.wrap(shape(), rawBuffer); + } + + @Override + public Class type() { + return TString.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + @Override + public TString using(Charset charset) { + return new TStringImpl(tensorHandle, rawBuffer, shape(), DataLayouts.ofStrings(charset)); + } + + @Override + public void write(NdArray src, Function getBytes) { + rawBuffer.init(src, getBytes); + } + + TStringImpl(TensorHandle tensorHandle, ByteSequenceTensorBuffer rawBuffer, Shape shape) { + this(tensorHandle, rawBuffer, shape, UTF_8_LAYOUT); + } + + private TStringImpl( + TensorHandle tensorHandle, + ByteSequenceTensorBuffer rawBuffer, + Shape shape, + DataLayout, String> layout + ) { + super(layout.applyTo(rawBuffer), shape); + this.rawBuffer = rawBuffer; + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + private final ByteSequenceTensorBuffer rawBuffer; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Factory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Factory.java new file mode 100644 index 00000000000..13d3e5bc663 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Factory.java @@ -0,0 +1,41 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.AbstractTypeFactory; +import org.tensorflow.TensorHandle; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; +import org.tensorflow.types.TUint8; + +/** + * Factory of {@link TUint8} tensor instances + */ +public class TUint8Factory extends AbstractTypeFactory { + + @Override + public TUint8 createDense(TensorHandle tensorHandle, Shape shape) { + ByteDataBuffer buffer = TensorBuffers.toBytes(getNative(tensorHandle)); + return new TUint8Impl(tensorHandle, buffer, shape); + } + + private static final class TUint8Impl extends ByteDenseNdArray implements TUint8 { + + @Override + public Class type() { + return TUint8.class; + } + + @Override + public TensorHandle handle() { + return tensorHandle; + } + + TUint8Impl(TensorHandle tensorHandle, ByteDataBuffer buffer, Shape shape) { + super(buffer, shape); + this.tensorHandle = tensorHandle; + } + + private final TensorHandle tensorHandle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 6c214cc6819..3c2d34d3980 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -16,23 +16,10 @@ package org.tensorflow.op.core; import java.nio.charset.Charset; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; -import org.tensorflow.Tensor; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; -import org.tensorflow.ndarray.buffer.ByteDataBuffer; -import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.Tensors; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -41,7 +28,19 @@ import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -77,7 +76,7 @@ public final class Constant extends RawOp implements Operand */ @Endpoint public static Constant scalarOf(Scope scope, int data) { - try (Tensor value = TInt32.scalarOf(data)) { + try (TInt32 value = TInt32.scalarOf(data)) { return create(scope, value); } } @@ -92,7 +91,7 @@ public static Constant scalarOf(Scope scope, int data) { */ @Endpoint public static Constant vectorOf(Scope scope, int[] data) { - try (Tensor value = TInt32.vectorOf(data)) { + try (TInt32 value = TInt32.vectorOf(data)) { return create(scope, value); } } @@ -122,7 +121,7 @@ public static Constant arrayOf(Scope scope, int... data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -138,7 +137,7 @@ public static Constant tensorOf(Scope scope, int[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -154,7 +153,7 @@ public static Constant tensorOf(Scope scope, int[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -170,7 +169,7 @@ public static Constant tensorOf(Scope scope, int[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -186,7 +185,7 @@ public static Constant tensorOf(Scope scope, int[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -201,7 +200,7 @@ public static Constant tensorOf(Scope scope, int[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, IntNdArray data) { - try (Tensor value = TInt32.tensorOf(data)) { + try (TInt32 value = TInt32.tensorOf(data)) { return create(scope, value); } } @@ -217,7 +216,7 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer data) { - try (Tensor value = TInt32.tensorOf(shape, data)) { + try (TInt32 value = TInt32.tensorOf(shape, data)) { return create(scope, value); } } @@ -231,7 +230,7 @@ public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, float data) { - try (Tensor value = TFloat32.scalarOf(data)) { + try (TFloat32 value = TFloat32.scalarOf(data)) { return create(scope, value); } } @@ -246,7 +245,7 @@ public static Constant scalarOf(Scope scope, float data) { */ @Endpoint public static Constant vectorOf(Scope scope, float[] data) { - try (Tensor value = TFloat32.vectorOf(data)) { + try (TFloat32 value = TFloat32.vectorOf(data)) { return create(scope, value); } } @@ -276,7 +275,7 @@ public static Constant arrayOf(Scope scope, float... data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -292,7 +291,7 @@ public static Constant tensorOf(Scope scope, float[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -308,7 +307,7 @@ public static Constant tensorOf(Scope scope, float[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -324,7 +323,7 @@ public static Constant tensorOf(Scope scope, float[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -340,7 +339,7 @@ public static Constant tensorOf(Scope scope, float[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -355,7 +354,7 @@ public static Constant tensorOf(Scope scope, float[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, FloatNdArray data) { - try (Tensor value = TFloat32.tensorOf(data)) { + try (TFloat32 value = TFloat32.tensorOf(data)) { return create(scope, value); } } @@ -371,7 +370,7 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuffer data) { - try (Tensor value = TFloat32.tensorOf(shape, data)) { + try (TFloat32 value = TFloat32.tensorOf(shape, data)) { return create(scope, value); } } @@ -385,7 +384,7 @@ public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuf */ @Endpoint public static Constant scalarOf(Scope scope, double data) { - try (Tensor value = TFloat64.scalarOf(data)) { + try (TFloat64 value = TFloat64.scalarOf(data)) { return create(scope, value); } } @@ -400,7 +399,7 @@ public static Constant scalarOf(Scope scope, double data) { */ @Endpoint public static Constant vectorOf(Scope scope, double[] data) { - try (Tensor value = TFloat64.vectorOf(data)) { + try (TFloat64 value = TFloat64.vectorOf(data)) { return create(scope, value); } } @@ -430,7 +429,7 @@ public static Constant arrayOf(Scope scope, double... data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -446,7 +445,7 @@ public static Constant tensorOf(Scope scope, double[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -462,7 +461,7 @@ public static Constant tensorOf(Scope scope, double[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -478,7 +477,7 @@ public static Constant tensorOf(Scope scope, double[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -494,7 +493,7 @@ public static Constant tensorOf(Scope scope, double[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -509,7 +508,7 @@ public static Constant tensorOf(Scope scope, double[][][][][][] data) */ @Endpoint public static Constant tensorOf(Scope scope, DoubleNdArray data) { - try (Tensor value = TFloat64.tensorOf(data)) { + try (TFloat64 value = TFloat64.tensorOf(data)) { return create(scope, value); } } @@ -525,7 +524,7 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBuffer data) { - try (Tensor value = TFloat64.tensorOf(shape, data)) { + try (TFloat64 value = TFloat64.tensorOf(shape, data)) { return create(scope, value); } } @@ -539,7 +538,7 @@ public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBu */ @Endpoint public static Constant scalarOf(Scope scope, long data) { - try (Tensor value = TInt64.scalarOf(data)) { + try (TInt64 value = TInt64.scalarOf(data)) { return create(scope, value); } } @@ -554,7 +553,7 @@ public static Constant scalarOf(Scope scope, long data) { */ @Endpoint public static Constant vectorOf(Scope scope, long[] data) { - try (Tensor value = TInt64.vectorOf(data)) { + try (TInt64 value = TInt64.vectorOf(data)) { return create(scope, value); } } @@ -569,7 +568,7 @@ public static Constant vectorOf(Scope scope, long[] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -600,7 +599,7 @@ public static Constant arrayOf(Scope scope, long... data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -616,7 +615,7 @@ public static Constant tensorOf(Scope scope, long[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -632,7 +631,7 @@ public static Constant tensorOf(Scope scope, long[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -648,7 +647,7 @@ public static Constant tensorOf(Scope scope, long[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -663,7 +662,7 @@ public static Constant tensorOf(Scope scope, long[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, LongNdArray data) { - try (Tensor value = TInt64.tensorOf(data)) { + try (TInt64 value = TInt64.tensorOf(data)) { return create(scope, value); } } @@ -679,7 +678,7 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer data) { - try (Tensor value = TInt64.tensorOf(shape, data)) { + try (TInt64 value = TInt64.tensorOf(shape, data)) { return create(scope, value); } } @@ -693,7 +692,7 @@ public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, boolean data) { - try (Tensor value = TBool.scalarOf(data)) { + try (TBool value = TBool.scalarOf(data)) { return create(scope, value); } } @@ -708,7 +707,7 @@ public static Constant scalarOf(Scope scope, boolean data) { */ @Endpoint public static Constant vectorOf(Scope scope, boolean[] data) { - try (Tensor value = TBool.vectorOf(data)) { + try (TBool value = TBool.vectorOf(data)) { return create(scope, value); } } @@ -738,7 +737,7 @@ public static Constant arrayOf(Scope scope, boolean... data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -754,7 +753,7 @@ public static Constant tensorOf(Scope scope, boolean[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -770,7 +769,7 @@ public static Constant tensorOf(Scope scope, boolean[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -786,7 +785,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -802,7 +801,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -817,7 +816,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, BooleanNdArray data) { - try (Tensor value = TBool.tensorOf(data)) { + try (TBool value = TBool.tensorOf(data)) { return create(scope, value); } } @@ -833,7 +832,7 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuffer data) { - try (Tensor value = TBool.tensorOf(shape, data)) { + try (TBool value = TBool.tensorOf(shape, data)) { return create(scope, value); } } @@ -847,7 +846,7 @@ public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuff */ @Endpoint public static Constant scalarOf(Scope scope, byte data) { - try (Tensor value = TUint8.scalarOf(data)) { + try (TUint8 value = TUint8.scalarOf(data)) { return create(scope, value); } } @@ -862,7 +861,7 @@ public static Constant scalarOf(Scope scope, byte data) { */ @Endpoint public static Constant vectorOf(Scope scope, byte[] data) { - try (Tensor value = TUint8.vectorOf(data)) { + try (TUint8 value = TUint8.vectorOf(data)) { return create(scope, value); } } @@ -892,7 +891,7 @@ public static Constant arrayOf(Scope scope, byte... data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -908,7 +907,7 @@ public static Constant tensorOf(Scope scope, byte[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -924,7 +923,7 @@ public static Constant tensorOf(Scope scope, byte[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -940,7 +939,7 @@ public static Constant tensorOf(Scope scope, byte[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -956,7 +955,7 @@ public static Constant tensorOf(Scope scope, byte[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -971,7 +970,7 @@ public static Constant tensorOf(Scope scope, byte[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, ByteNdArray data) { - try (Tensor value = TUint8.tensorOf(data)) { + try (TUint8 value = TUint8.tensorOf(data)) { return create(scope, value); } } @@ -987,7 +986,7 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer data) { - try (Tensor value = TUint8.tensorOf(shape, data)) { + try (TUint8 value = TUint8.tensorOf(shape, data)) { return create(scope, value); } } @@ -996,7 +995,7 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer * Create a constant with data from the given buffer. * * @param scope is a scope used to add the underlying operation. - * @param type the tensor datatype. + * @param type the tensor type. * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @return a constant of type `type` @@ -1004,9 +1003,9 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer * buffer */ @Endpoint - public static Constant tensorOf(Scope scope, DataType type, Shape shape, + public static Constant tensorOf(Scope scope, Class type, Shape shape, ByteDataBuffer data) { - try (Tensor value = Tensor.of(type, shape, data)) { + try (T value = Tensors.of(type, shape, data)) { return create(scope, value); } } @@ -1020,7 +1019,7 @@ public static Constant tensorOf(Scope scope, DataType ty */ @Endpoint public static Constant scalarOf(Scope scope, String data) { - try (Tensor value = TString.scalarOf(data)) { + try (TString value = TString.scalarOf(data)) { return create(scope, value); } } @@ -1035,7 +1034,7 @@ public static Constant scalarOf(Scope scope, String data) { */ @Endpoint public static Constant scalarOf(Scope scope, Charset charset, String data) { - try (Tensor value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { + try (TString value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { return create(scope, value); } } @@ -1049,7 +1048,7 @@ public static Constant scalarOf(Scope scope, Charset charset, String da */ public static Constant vectorOf(Scope scope, String[] data) { NdArray src = NdArrays.vectorOfObjects(data); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1065,7 +1064,7 @@ public static Constant vectorOf(Scope scope, String[] data) { */ @Endpoint public static Constant vectorOf(Scope scope, Charset charset, String[] data) { - try (Tensor value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { + try (TString value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { return Constant.create(scope, value); } } @@ -1112,7 +1111,7 @@ public static Constant arrayOf(Scope scope, Charset charset, String... public static Constant tensorOf(Scope scope, String[][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1127,7 +1126,7 @@ public static Constant tensorOf(Scope scope, String[][] data) { public static Constant tensorOf(Scope scope, String[][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1142,7 +1141,7 @@ public static Constant tensorOf(Scope scope, String[][][] data) { public static Constant tensorOf(Scope scope, String[][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1157,7 +1156,7 @@ public static Constant tensorOf(Scope scope, String[][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1172,7 +1171,7 @@ public static Constant tensorOf(Scope scope, String[][][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1187,7 +1186,7 @@ public static Constant tensorOf(Scope scope, String[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, NdArray data) { - try (Tensor value = TString.tensorOf(data)) { + try (TString value = TString.tensorOf(data)) { return create(scope, value); } } @@ -1203,7 +1202,7 @@ public static Constant tensorOf(Scope scope, NdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Charset charset, NdArray data) { - try (Tensor value = TString.tensorOf(charset, data)) { + try (TString value = TString.tensorOf(charset, data)) { return create(scope, value); } } @@ -1220,7 +1219,7 @@ public static Constant tensorOf(Scope scope, Charset charset, NdArray tensorOf(Scope scope, Shape shape, DataBuffer data) { - try (Tensor value = TString.tensorOf(shape, data)) { + try (TString value = TString.tensorOf(shape, data)) { return create(scope, value); } } @@ -1238,7 +1237,7 @@ public static Constant tensorOf(Scope scope, Shape shape, DataBuffer tensorOf(Scope scope, Charset charset, Shape shape, DataBuffer data) { - try (Tensor value = TString.tensorOf(charset, shape, data)) { + try (TString value = TString.tensorOf(charset, shape, data)) { return create(scope, value); } } @@ -1264,13 +1263,13 @@ public static Constant tensorOf(Scope scope, Shape shape) { * @return a constant of the same data type as `tensor` */ @Endpoint - public static Constant create(Scope scope, Tensor tensor) { + public static Constant create(Scope scope, T tensor) { return new Constant<>( scope .env() .opBuilder("Const", scope.makeOpName("Const")) .setAttr("value", tensor) - .setAttr("dtype", tensor.dataType()) + .setAttr("dtype", tensor.type()) .build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java index 2827276c32c..45ac0d1180e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java @@ -22,7 +22,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.Op; import org.tensorflow.op.Operands; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -101,9 +100,11 @@ public static Gradients create( } } } - Output[] dy = - graph.addGradients( - scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); + Output[] dy = graph.addGradients( + scope.makeOpName("Gradients"), + Operands.asOutputs(y), + Operands.asOutputs(x), dx + ); return new Gradients(Arrays.asList(dy)); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index f9ce837fe60..6a3c3521997 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -47,7 +47,7 @@ private Helpers() {} @Endpoint(name = "variable") public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { Output initOutput = init.asOutput(); - Variable newVar = Variable.create(scope,initOutput.shape(), initOutput.dataType(), options); + Variable newVar = Variable.create(scope,initOutput.shape(), initOutput.type(), options); Assign assignOp = Assign.create(scope, newVar, init); Init.add(scope, assignOp); return newVar; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java index 613cb729341..a85fdbc1ab2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java @@ -15,16 +15,12 @@ package org.tensorflow.op.core; import java.util.Arrays; - -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; - -import org.tensorflow.op.math.FloorMod; - import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.FloorMod; import org.tensorflow.op.math.NotEqual; import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBool; @@ -51,8 +47,8 @@ * Operand numPred = tf.shape.size(predShape, tf.constant(0)); * Operand predFlat = tf.shape.flatten(yPred); * - * Shape predShape64 = tf.shape(yPred, TInt64.DTYPE); - * Operand predSqueezed = tf.shape.squeeze(predShape64, TInt64.DTYPE); + * Shape predShape64 = tf.shape(yPred, TInt64.class); + * Operand predSqueezed = tf.shape.squeeze(predShape64, TInt64.class); * } */ @Operator(group = "shape") @@ -68,7 +64,7 @@ public abstract class Shapes { */ @Endpoint(name = "flatten") public static Operand flatten(Scope scope, Operand operand) { - return flatten(scope, operand, TInt32.DTYPE); + return flatten(scope, operand, TInt32.class); } /** @@ -83,7 +79,7 @@ public static Operand flatten(Scope scope, Operand opera */ @Endpoint(name = "flatten") public static Operand flatten( - Scope scope, Operand operand, DataType dType) { + Scope scope, Operand operand, Class dType) { Operand flatShape = flatten(scope, Shape.create(scope, operand, dType), dType); return Reshape.create(scope, operand, flatShape); } @@ -97,7 +93,7 @@ public static Operand flatten( */ @Endpoint(name = "flatten") public static Operand flatten(Scope scope, Shape shape) { - return flatten(scope, shape, TInt32.DTYPE); + return flatten(scope, shape, TInt32.class); } /** @@ -111,11 +107,11 @@ public static Operand flatten(Scope scope, Shape shape) { */ @Endpoint(name = "flatten") public static Operand flatten( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { return ExpandDims.create( scope, size(scope, shape, dType), - Cast.create(scope, Constant.scalarOf(scope, -1), TInt32.DTYPE)); + Cast.create(scope, Constant.scalarOf(scope, -1), TInt32.class)); } /** @@ -127,7 +123,7 @@ public static Operand flatten( */ @Endpoint(name = "size") public static Operand size(Scope scope, Shape shape) { - return size(scope, shape, TInt32.DTYPE); + return size(scope, shape, TInt32.class); } /** @@ -141,7 +137,7 @@ public static Operand size(Scope scope, Shape shape) { */ @Endpoint(name = "size") public static Operand size( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { Slice dims = Slice.create( scope, @@ -164,7 +160,7 @@ public static Operand size( */ @Endpoint(name = "size") public static Operand size(Scope scope, Shape shape, Operand dim) { - return size(scope, shape, dim, TInt32.DTYPE); + return size(scope, shape, dim, TInt32.class); } /** @@ -179,7 +175,7 @@ public static Operand size(Scope scope, Shape shape, Operand Operand size( - Scope scope, Shape shape, Operand dim, DataType dType) { + Scope scope, Shape shape, Operand dim, Class dType) { return Slice.create( scope, shape, @@ -201,7 +197,7 @@ public static Operand size( @Endpoint(name = "size") public static Operand size( Scope scope, Operand input, Operand dim) { - return size(scope, input, dim, TInt32.DTYPE); + return size(scope, input, dim, TInt32.class); } /** @@ -216,7 +212,7 @@ public static Operand size( */ @Endpoint(name = "size") public static Operand size( - Scope scope, Operand input, Operand dim, DataType dType) { + Scope scope, Operand input, Operand dim, Class dType) { return size(scope, Shape.create(scope, input, dType), dim, dType); } @@ -229,7 +225,7 @@ public static Operand size( */ @Endpoint(name = "numDimensions") public static Operand numDimensions(Scope scope, Shape shape) { - return Size.create(scope, shape, TInt32.DTYPE); + return Size.create(scope, shape, TInt32.class); } /** @@ -243,7 +239,7 @@ public static Operand numDimensions(Scope scope, Shape shape) { */ @Endpoint(name = "numDimensions") public static Operand numDimensions( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { return Size.create(scope, shape, dType); } @@ -259,7 +255,7 @@ public static Operand numDimensions( @Endpoint(name = "reduceDims") public static Operand reduceDims( Scope scope, Operand operand, Operand axis) { - return reduceDims(scope, operand, axis, TInt32.DTYPE); + return reduceDims(scope, operand, axis, TInt32.class); } /** @@ -275,7 +271,7 @@ public static Operand reduceDims( */ @Endpoint(name = "reduceDims") public static Operand reduceDims( - Scope scope, Operand operand, Operand axis, DataType dType) { + Scope scope, Operand operand, Operand axis, Class dType) { Shape newShape = Shape.create(scope, operand, dType); return Reshape.create(scope, operand, reduceDims(scope, newShape, axis, dType)); } @@ -290,7 +286,7 @@ public static Operand reduceDims( */ @Endpoint(name = "reduceDims") public static Operand reduceDims(Scope scope, Shape shape, Operand axis) { - return reduceDims(scope, shape, axis, TInt32.DTYPE); + return reduceDims(scope, shape, axis, TInt32.class); } /** @@ -305,7 +301,7 @@ public static Operand reduceDims(Scope scope, Shape shape, Opera */ @Endpoint(name = "reduceDims") public static Operand reduceDims( - Scope scope, Shape shape, Operand axis, DataType dType) { + Scope scope, Shape shape, Operand axis, Class dType) { Size rank = Size.create(scope, shape, dType); axis = FloorMod.create(scope, axis, rank); Sub remainder = Sub.create(scope, rank, axis); @@ -343,7 +339,7 @@ public static Operand reduceDims( */ @Endpoint(name = "squeeze") public static Operand squeeze(Scope scope, Shape shape) { - return squeeze(scope, shape, TInt32.DTYPE); + return squeeze(scope, shape, TInt32.class); } /** @@ -357,7 +353,7 @@ public static Operand squeeze(Scope scope, Shape shape) { */ @Endpoint(name = "squeeze") public static Operand squeeze( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { Operand mask = NotEqual.create(scope, shape, Cast.create(scope, OnesLike.create(scope, shape), dType)); @@ -373,7 +369,7 @@ public static Operand squeeze( */ @Endpoint(name = "head") public static Operand head(Scope scope, Shape shape) { - return head(scope, shape, TInt32.DTYPE); + return head(scope, shape, TInt32.class); } /** @@ -387,7 +383,7 @@ public static Operand head(Scope scope, Shape shape) { */ @Endpoint(name = "head") public static Operand head( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { return take(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), dType), dType); } @@ -403,7 +399,7 @@ public static Operand head( */ @Endpoint(name = "take") public static Operand take(Scope scope, Shape shape, Operand n) { - return take(scope, shape, n, TInt32.DTYPE); + return take(scope, shape, n, TInt32.class); } /** @@ -420,7 +416,7 @@ public static Operand take(Scope scope, Shape shape, Operand Operand take( - Scope scope, Shape shape, Operand n, DataType dType) { + Scope scope, Shape shape, Operand n, Class dType) { return Slice.create( scope, shape, @@ -439,7 +435,7 @@ public static Operand take( */ @Endpoint(name = "tail") public static Operand tail(Scope scope, Shape shape) { - return tail(scope, shape, TInt32.DTYPE); + return tail(scope, shape, TInt32.class); } /** @@ -455,7 +451,7 @@ public static Operand tail(Scope scope, Shape shape) { */ @Endpoint(name = "tail") public static Operand tail( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class dType) { return takeLast(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), dType), dType); } @@ -472,7 +468,7 @@ public static Operand tail( @Endpoint(name = "takeLast") public static Operand takeLast( Scope scope, Shape shape, Operand n) { - return takeLast(scope, shape, n, TInt32.DTYPE); + return takeLast(scope, shape, n, TInt32.class); } /** @@ -489,7 +485,7 @@ public static Operand takeLast( */ @Endpoint(name = "takeLast") public static Operand takeLast( - Scope scope, Shape shape, Operand n, DataType dType) { + Scope scope, Shape shape, Operand n, Class dType) { Size rank = Size.create(scope, shape, dType); Sub start = Sub.create(scope, rank, n); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java index 4aad417b117..3de8b1d26e8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java @@ -14,7 +14,6 @@ ==============================================================================*/ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; @@ -51,10 +50,10 @@ public final class Zeros implements Op, Operand { */ @Endpoint @SuppressWarnings("unchecked") - public static Zeros create(Scope scope, Operand dims, DataType type) { + public static Zeros create(Scope scope, Operand dims, Class type) { Scope zerosScope = scope.withSubScope("Zeros"); Operand zero; - if (type == TString.DTYPE) { + if (type == TString.class) { zero = (Operand)Constant.scalarOf(zerosScope.withName("Zero"), ""); } else { zero = Cast.create(zerosScope.withName("Zero"), Constant.scalarOf(zerosScope, 0), type); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java index 4f3e9569103..191e4b3c122 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java @@ -8,7 +8,13 @@ import org.tensorflow.op.core.Select; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.math.*; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Exp; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.Log1p; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Neg; +import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; @@ -72,7 +78,7 @@ public static Operand sigmoidCrossEntropyWithLogits( scope = scope.withSubScope("SigmoidCrossEntropyWithLogits"); Operand zeros = - Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); + Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().type()); Operand cond = GreaterEqual.create(scope, logits, zeros); Operand reluLogits = Select.create(scope, cond, logits, zeros); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 0c8bac697ed..d958c1272b1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -1,12 +1,18 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; +import java.util.Arrays; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.*; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Slice; import org.tensorflow.op.dtypes.Cast; import org.tensorflow.op.linalg.Transpose; import org.tensorflow.op.math.Sub; @@ -15,10 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; - -import java.util.Arrays; -import java.util.List; @Operator(group = "nn") public class SoftmaxCrossEntropyWithLogits { @@ -78,24 +80,20 @@ public static Operand softmaxCrossEntr axis += logits.asOutput().shape().numDimensions(); } - - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { + if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { Operand result = softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, TFloat32.DTYPE), - Cast.create(scope, logits, TFloat32.DTYPE), + Cast.create(scope, labels, TFloat32.class), + Cast.create(scope, logits, TFloat32.class), axis); - return Cast.create(scope, result, logits.asOutput().dataType()); - } else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) { + return Cast.create(scope, result, logits.asOutput().type()); + } else if(!logits.asOutput().type().equals(labels.asOutput().type())) { return softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, logits.asOutput().dataType()), + Cast.create(scope, labels, logits.asOutput().type()), logits, axis); } - Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); Shape shape = logits.asOutput().shape(); // Move the dim to the end if dim is not the last dimension. @@ -167,13 +165,13 @@ private static Operand flattenOuterDims(Scope scope, Oper } } - Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); Operand rankMinusOne = Sub.create(scope, rank, one); Operand lastDimSize = Slice.create( scope, - org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), + org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class), rankMinusOne, one); Operand concat = @@ -197,7 +195,7 @@ private static Operand flattenOuterDims(Scope scope, Oper */ private static Operand moveDimToEnd( Scope scope, Operand input, int dimIndex, Operand rank) { - DataType rankDType = rank.asOutput().dataType(); + Class rankDType = rank.asOutput().type(); Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); List> concatList = Arrays.asList( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index ebd6f74e7d8..79840827c13 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -1,5 +1,8 @@ package org.tensorflow.op.nn; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -18,10 +21,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - @Operator(group = "nn") public class SparseSoftmaxCrossEntropyWithLogits { @@ -74,11 +73,8 @@ public static Operand sparseSoftmaxCrossE scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); /** cannot use generics on preciseLogits as it may be recast later */ Operand preciseLogits = logits; - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - preciseLogits = Cast.create(scope, logits, TFloat32.DTYPE); + if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { + preciseLogits = Cast.create(scope, logits, TFloat32.class); } Shape labelsStaticShape = labels.asOutput().shape(); org.tensorflow.op.core.Shape labelsShape = @@ -115,8 +111,8 @@ public static Operand sparseSoftmaxCrossE org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); Operand loss = smax.loss(); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - loss = Cast.create(scope, loss, TFloat16.DTYPE); + if (logits.asOutput().type() == TFloat16.class) { + loss = Cast.create(scope, loss, TFloat16.class); } return loss; } @@ -153,8 +149,8 @@ public static Operand sparseSoftmaxCrossE scope, preciseLogits, labels); Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - cost = Cast.create(scope, cost, TFloat16.DTYPE); + if (logits.asOutput().type() == TFloat16.class) { + cost = Cast.create(scope, cost, TFloat16.class); } return cost; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index 50f6ea49b06..6b56c310e47 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -18,18 +18,16 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.internal.types.TBfloat16Factory; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** @@ -48,12 +46,8 @@ *

Note that some CPUs support the bfloat16 format natively, which can result in faster * computation compared to {@link TFloat16} when GPUs are not used. */ -public interface TBfloat16 extends FloatNdArray, TFloating { - /** readable-name for the data type */ - static final String NAME = "BFLOAT16"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 14, 2, TBfloat16Impl::mapTensor); +@TensorType(dataType = DataType.DT_BFLOAT16, byteSize = 2, factory = TBfloat16Factory.class) +public interface TBfloat16 extends TFloating, FloatNdArray { /** * Allocates a new tensor for storing a single float value. @@ -61,8 +55,8 @@ public interface TBfloat16 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + static TBfloat16 scalarOf(float value) { + return Tensors.of(TBfloat16.class, Shape.scalar(), t -> t.setFloat(value)); } /** @@ -71,11 +65,11 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TBfloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TBfloat16.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -86,8 +80,8 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TBfloat16 tensorOf(NdArray src) { + return Tensors.of(TBfloat16.class, src.shape(), src::copyTo); } /** @@ -96,8 +90,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TBfloat16 tensorOf(Shape shape) { + return Tensors.of(TBfloat16.class, shape); } /** @@ -107,32 +101,20 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { + return Tensors.of(TBfloat16.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + static TBfloat16 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TBfloat16.class, shape, tensorInit); } } -/** Hidden implementation of a {@code TBfloat16} */ -class TBfloat16Impl extends FloatDenseNdArray implements TBfloat16 { - - static TBfloat16 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TBfloat16Impl( - DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeTensor)), shape); - } - - private TBfloat16Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index 3cc72101893..f5e5976bc73 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -17,22 +17,20 @@ package org.tensorflow.types; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import java.util.function.Consumer; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TBoolFactory; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TType; -import java.util.function.Consumer; - /** * Boolean tensor type. * @@ -40,12 +38,8 @@ * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL * BOOL} layout, which may impact I/O performances. */ -public interface TBool extends BooleanNdArray, TType { - /** readable-name for the data type */ - static final String NAME = "BOOL"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 10, 1, TBoolImpl::mapTensor); +@TensorType(dataType = DataType.DT_BOOL, byteSize = 1, factory = TBoolFactory.class) +public interface TBool extends TType, BooleanNdArray { /** * Allocates a new tensor for storing a single boolean value. @@ -53,8 +47,8 @@ public interface TBool extends BooleanNdArray, TType { * @param value boolean to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(boolean value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setBoolean(value)); + static TBool scalarOf(boolean value) { + return Tensors.of(TBool.class, Shape.scalar(), t -> t.setBoolean(value)); } /** @@ -63,11 +57,11 @@ static Tensor scalarOf(boolean value) { * @param values booleans to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(boolean... values) { + static TBool vectorOf(boolean... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TBool.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -78,8 +72,8 @@ static Tensor vectorOf(boolean... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TBool tensorOf(NdArray src) { + return Tensors.of(TBool.class, src.shape(), src::copyTo); } /** @@ -88,8 +82,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TBool tensorOf(Shape shape) { + return Tensors.of(TBool.class, shape); } /** @@ -99,31 +93,20 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of booleans to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, BooleanDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TBool tensorOf(Shape shape, BooleanDataBuffer data) { + return Tensors.of(TBool.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + static TBool tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TBool.class, shape, tensorInit); } -} -/** Hidden implementation of a {@code TBool} */ -class TBoolImpl extends BooleanDenseNdArray implements TBool { - - static TBool mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TBoolImpl(TensorBuffers.toBooleans(nativeTensor), shape); - } - - private TBoolImpl(BooleanDataBuffer buffer, Shape shape) { - super(buffer, shape); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index 0cd441a1ff1..7add4b25346 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -18,18 +18,16 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.internal.types.TFloat16Factory; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** @@ -45,13 +43,8 @@ * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link * TBfloat16} tensor type might be a better option. */ -public interface TFloat16 extends FloatNdArray, TFloating { - - /** readable-name for the data type */ - static final String NAME = "FLOAT16"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 19, 2, TFloat16Impl::mapTensor); +@TensorType(dataType = DataType.DT_HALF, byteSize = 2, factory = TFloat16Factory.class) +public interface TFloat16 extends TFloating, FloatNdArray { /** * Allocates a new tensor for storing a single float value. @@ -59,8 +52,8 @@ public interface TFloat16 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + static TFloat16 scalarOf(float value) { + return Tensors.of(TFloat16.class, Shape.scalar(), t -> t.setFloat(value)); } /** @@ -69,11 +62,11 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TFloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TFloat16.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -84,8 +77,8 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TFloat16 tensorOf(NdArray src) { + return Tensors.of(TFloat16.class, src.shape(), src::copyTo); } /** @@ -94,8 +87,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TFloat16 tensorOf(Shape shape) { + return Tensors.of(TFloat16.class, shape); } /** @@ -105,32 +98,20 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { + return Tensors.of(TFloat16.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + static TFloat16 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TFloat16.class, shape, tensorInit); } -} - -/** Hidden implementation of a {@code TFloat16} */ -class TFloat16Impl extends FloatDenseNdArray implements TFloat16 { - static TFloat16 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat16Impl( - DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeTensor)), shape); - } - - private TFloat16Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 571ec118ddc..32bc85ffa3e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -18,27 +18,23 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.internal.types.TFloat32Factory; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; -/** IEEE-754 single-precision 32-bit float tensor type. */ -public interface TFloat32 extends FloatNdArray, TFloating { - - /** readable-name for the data type */ - static final String NAME = "FLOAT"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 1, 4, TFloat32Impl::mapTensor); +/** + * IEEE-754 single-precision 32-bit float tensor type. + */ +@TensorType(dataType = DataType.DT_FLOAT, byteSize = 4, factory = TFloat32Factory.class) +public interface TFloat32 extends TFloating, FloatNdArray { /** * Allocates a new tensor for storing a single float value. @@ -46,8 +42,8 @@ public interface TFloat32 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + static TFloat32 scalarOf(float value) { + return Tensors.of(TFloat32.class, Shape.scalar(), t -> t.setFloat(value)); } /** @@ -56,11 +52,11 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TFloat32 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TFloat32.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -71,8 +67,8 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TFloat32 tensorOf(NdArray src) { + return Tensors.of(TFloat32.class, src.shape(), src::copyTo); } /** @@ -81,8 +77,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TFloat32 tensorOf(Shape shape) { + return Tensors.of(TFloat32.class, shape); } /** @@ -92,31 +88,20 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { + return Tensors.of(TFloat32.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); - } -} - -/** Hidden implementation of a {@code TFloat32} */ -class TFloat32Impl extends FloatDenseNdArray implements TFloat32 { - - static TFloat32 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat32Impl(TensorBuffers.toFloats(nativeTensor), shape); + static TFloat32 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TFloat32.class, shape, tensorInit); } - private TFloat32Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 5d2744c4b3c..cc5799bf27f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -18,28 +18,23 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.internal.types.TFloat64Factory; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; - -/** IEEE-754 double-precision 64-bit float tensor type. */ -public interface TFloat64 extends DoubleNdArray, TFloating { - - /** readable-name for the data type */ - static final String NAME = "DOUBLE"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 2, 8, TFloat64Impl::mapTensor); +/** + * IEEE-754 double-precision 64-bit float tensor type. + */ +@TensorType(dataType = DataType.DT_DOUBLE, byteSize = 8, factory = TFloat64Factory.class) +public interface TFloat64 extends TFloating, DoubleNdArray { /** * Allocates a new tensor for storing a single double value. @@ -47,8 +42,8 @@ public interface TFloat64 extends DoubleNdArray, TFloating { * @param value double to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(double value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setDouble(value)); + static TFloat64 scalarOf(double value) { + return Tensors.of(TFloat64.class, Shape.scalar(), t -> t.setDouble(value)); } /** @@ -57,11 +52,11 @@ static Tensor scalarOf(double value) { * @param values doubles to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(double... values) { + static TFloat64 vectorOf(double... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TFloat64.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -72,8 +67,8 @@ static Tensor vectorOf(double... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TFloat64 tensorOf(NdArray src) { + return Tensors.of(TFloat64.class, src.shape(), src::copyTo); } /** @@ -82,8 +77,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TFloat64 tensorOf(Shape shape) { + return Tensors.of(TFloat64.class, shape); } /** @@ -93,31 +88,19 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of doubles to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, DoubleDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { + return Tensors.of(TFloat64.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); - } -} - -/** Hidden implementation of a {@code TFloat64} */ -class TFloat64Impl extends DoubleDenseNdArray implements TFloat64 { - - static TFloat64 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat64Impl(TensorBuffers.toDoubles(nativeTensor), shape); - } - - private TFloat64Impl(DoubleDataBuffer buffer, Shape shape) { - super(buffer, shape); + static TFloat64 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TFloat64.class, shape, tensorInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index 4a1139ddde2..ff2f7277ab8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -18,26 +18,23 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.Tensors; +import org.tensorflow.internal.types.TInt32Factory; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TInteger; import org.tensorflow.types.family.TNumber; -/** 32-bit signed integer tensor type. */ -public interface TInt32 extends IntNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "INT32"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 3, 4, TInt32Impl::mapTensor); +/** + * 32-bit signed integer tensor type. + */ +@TensorType(dataType = DataType.DT_INT32, byteSize = 4, factory = TInt32Factory.class) +public interface TInt32 extends TInteger, IntNdArray { /** * Allocates a new tensor for storing a single int value. @@ -45,8 +42,8 @@ public interface TInt32 extends IntNdArray, TNumber { * @param value int to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(int value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setInt(value)); + static TInt32 scalarOf(int value) { + return Tensors.of(TInt32.class, Shape.scalar(), t -> t.setInt(value)); } /** @@ -56,11 +53,11 @@ static Tensor scalarOf(int value) { * @return the new tensor * @throws IllegalArgumentException if no values are provided */ - static Tensor vectorOf(int... values) { + static TInt32 vectorOf(int... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TInt32.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -71,8 +68,8 @@ static Tensor vectorOf(int... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TInt32 tensorOf(NdArray src) { + return Tensors.of(TInt32.class, src.shape(), src::copyTo); } /** @@ -81,8 +78,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TInt32 tensorOf(Shape shape) { + return Tensors.of(TInt32.class, shape); } /** @@ -92,30 +89,18 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of ints to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, IntDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TInt32 tensorOf(Shape shape, IntDataBuffer data) { + return Tensors.of(TInt32.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); - } -} - -/** Hidden implementation of a {@code TInt32} */ -class TInt32Impl extends IntDenseNdArray implements TInt32 { - - static TInt32 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TInt32Impl(TensorBuffers.toInts(nativeTensor), shape); - } - - private TInt32Impl(IntDataBuffer buffer, Shape shape) { - super(buffer, shape); + static TInt32 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TInt32.class, shape, tensorInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 04fd4fd7799..cbefb248454 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -18,27 +18,24 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.internal.types.TInt64Factory; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TInteger; import org.tensorflow.types.family.TNumber; -/** 64-bit signed integer tensor type. */ -public interface TInt64 extends LongNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "INT64"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 9, 8, TInt64Impl::mapTensor); +/** + * 64-bit signed integer tensor type. + */ +@TensorType(dataType = DataType.DT_INT64, byteSize = 8, factory = TInt64Factory.class) +public interface TInt64 extends TInteger, LongNdArray { /** * Allocates a new tensor for storing a single long value. @@ -46,8 +43,8 @@ public interface TInt64 extends LongNdArray, TNumber { * @param value long to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(long value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setLong(value)); + static TInt64 scalarOf(long value) { + return Tensors.of(TInt64.class, Shape.scalar(), t -> t.setLong(value)); } /** @@ -56,11 +53,11 @@ static Tensor scalarOf(long value) { * @param values longs to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(long... values) { + static TInt64 vectorOf(long... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TInt64.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -71,8 +68,8 @@ static Tensor vectorOf(long... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TInt64 tensorOf(NdArray src) { + return Tensors.of(TInt64.class, src.shape(), src::copyTo); } /** @@ -81,8 +78,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TInt64 tensorOf(Shape shape) { + return Tensors.of(TInt64.class, shape); } /** @@ -92,31 +89,19 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of longs to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, LongDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TInt64 tensorOf(Shape shape, LongDataBuffer data) { + return Tensors.of(TInt64.class, shape, t -> t.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); - } -} - -/** Hidden implementation of a {@code TInt64} */ -class TInt64Impl extends LongDenseNdArray implements TInt64 { - - static TInt64 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TInt64Impl(TensorBuffers.toLongs(nativeTensor), shape); - } - - private TInt64Impl(LongDataBuffer buffer, Shape shape) { - super(buffer, shape); + static TInt64 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TInt64.class, shape, tensorInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index 57a121edcf1..8c29ffb5aa6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -17,24 +17,20 @@ package org.tensorflow.types; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; -import org.tensorflow.internal.buffer.StringTensorBuffer; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; +import org.tensorflow.Tensors; +import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; +import org.tensorflow.internal.types.TStringFactory; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayout; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TType; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.function.Function; - /** * String type. * @@ -44,13 +40,8 @@ * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the * data in the tensor is initialized once and cannot be modified afterwards. */ -public interface TString extends NdArray, TType { - - /** readable-name for the data type */ - static final String NAME = "STRING"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 7, -1, TStringImpl::mapTensor); +@TensorType(dataType = DataType.DT_STRING, byteSize = -1, factory = TStringFactory.class) +public interface TString extends TType, NdArray { /** * Allocates a new tensor for storing a string scalar. @@ -60,7 +51,7 @@ public interface TString extends NdArray, TType { * @param value scalar value to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(String value) { + static TString scalarOf(String value) { return tensorOf(NdArrays.scalarOfObject(value)); } @@ -72,7 +63,7 @@ static Tensor scalarOf(String value) { * @param values values to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(String... values) { + static TString vectorOf(String... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -88,7 +79,7 @@ static Tensor vectorOf(String... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TString tensorOf(NdArray src) { return tensorOf(StandardCharsets.UTF_8, src); } @@ -113,8 +104,10 @@ static Tensor tensorOf(NdArray src) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(Charset charset, NdArray src) { - return TStringImpl.createTensor(src, s -> s.getBytes(charset)); + static TString tensorOf(Charset charset, NdArray src) { + Function getBytes = s -> s.getBytes(charset); + long size = ByteSequenceTensorBuffer.computeSize(src, getBytes); + return Tensors.of(TString.class, src.shape(), size, t -> t.write(src, getBytes)); } /** @@ -127,7 +120,7 @@ static Tensor tensorOf(Charset charset, NdArray src) { * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, DataBuffer data) { + static TString tensorOf(Shape shape, DataBuffer data) { return tensorOf(NdArrays.wrap(shape, data)); } @@ -154,7 +147,7 @@ static Tensor tensorOf(Shape shape, DataBuffer data) { * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Charset charset, Shape shape, DataBuffer data) { + static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { return tensorOf(charset, NdArrays.wrap(shape, data)); } @@ -173,8 +166,10 @@ static Tensor tensorOf(Charset charset, Shape shape, DataBuffer * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOfBytes(NdArray src) { - return TStringImpl.createTensor(src, Function.identity()); + static TString tensorOfBytes(NdArray src) { + Function getBytes = Function.identity(); + long size = ByteSequenceTensorBuffer.computeSize(src, getBytes); + return Tensors.of(TString.class, src.shape(), size, t -> t.write(src, getBytes)); } /** @@ -193,7 +188,7 @@ static Tensor tensorOfBytes(NdArray src) { * @param data the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOfBytes(Shape shape, DataBuffer data) { + static TString tensorOfBytes(Shape shape, DataBuffer data) { return tensorOfBytes(NdArrays.wrap(shape, data)); } @@ -215,45 +210,7 @@ static Tensor tensorOfBytes(Shape shape, DataBuffer data) { */ TString using(Charset charset); - /** @return the tensor data as a n-dimensional array of raw byte sequences. */ - NdArray asBytes(); -} - -/** Hidden implementation of a {@code TString} */ -class TStringImpl extends DenseNdArray implements TString { + void write(NdArray src, Function getBytes); - @Override - public TString using(Charset charset) { - return new TStringImpl(tensorBuffer, DataLayouts.ofStrings(charset), shape()); - } - - @Override - public NdArray asBytes() { - return NdArrays.wrap(shape(), tensorBuffer); - } - - static Tensor createTensor(NdArray src, Function getBytes) { - long size = StringTensorBuffer.computeSize(src, getBytes); - return Tensor.of( - TString.DTYPE, - src.shape(), - size, - data -> ((TStringImpl) data).tensorBuffer.init(src, getBytes)); - } - - static TString mapTensor(TF_Tensor nativeTensor, Shape shape) { - StringTensorBuffer buffer = TensorBuffers.toStrings(nativeTensor, shape.size()); - return new TStringImpl(buffer, UTF_8_LAYOUT, shape); - } - - private static DataLayout, String> UTF_8_LAYOUT = - DataLayouts.ofStrings(StandardCharsets.UTF_8); - - private final StringTensorBuffer tensorBuffer; - - private TStringImpl( - StringTensorBuffer buffer, DataLayout, String> layout, Shape shape) { - super(layout.applyTo(buffer), shape); - tensorBuffer = buffer; - } + NdArray asBytes(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java index 365f41196fb..c92a2db560e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java @@ -18,27 +18,24 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.internal.types.TUint8Factory; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TInteger; import org.tensorflow.types.family.TNumber; -/** 8-bit unsigned integer tensor type. */ -public interface TUint8 extends ByteNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "UINT8"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 4, 1, TUint8Impl::mapTensor); +/** + * 8-bit unsigned integer tensor type. + */ +@TensorType(dataType = DataType.DT_UINT8, byteSize = 1, factory = TUint8Factory.class) +public interface TUint8 extends TInteger, ByteNdArray { /** * Allocates a new tensor for storing a single byte value. @@ -46,8 +43,8 @@ public interface TUint8 extends ByteNdArray, TNumber { * @param value byte to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(byte value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setByte(value)); + static TUint8 scalarOf(byte value) { + return Tensors.of(TUint8.class, Shape.scalar(), t -> t.setByte(value)); } /** @@ -56,11 +53,11 @@ static Tensor scalarOf(byte value) { * @param values bytes to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(byte... values) { + static TUint8 vectorOf(byte... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensors.of(TUint8.class, Shape.of(values.length), t -> StdArrays.copyTo(values, t)); } /** @@ -71,8 +68,8 @@ static Tensor vectorOf(byte... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + static TUint8 tensorOf(NdArray src) { + return Tensors.of(TUint8.class, src.shape(), src::copyTo); } /** @@ -81,8 +78,8 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + static TUint8 tensorOf(Shape shape) { + return Tensors.of(TUint8.class, shape); } /** @@ -92,31 +89,19 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of bytes to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, ByteDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + static TUint8 tensorOf(Shape shape, ByteDataBuffer data) { + return Tensors.of(TUint8.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * * @param shape shape of the tensor to allocate - * @param dataInit tensor data initializer + * @param tensorInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); - } -} - -/** Hidden implementation of a {@code TUint8} */ -class TUint8Impl extends ByteDenseNdArray implements TUint8 { - - static TUint8 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TUint8Impl(TensorBuffers.toBytes(nativeTensor), shape); - } - - private TUint8Impl(ByteDataBuffer buffer, Shape shape) { - super(buffer, shape); + static TUint8 tensorOf(Shape shape, Consumer tensorInit) { + return Tensors.of(TUint8.class, shape, tensorInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/Type.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/Type.java new file mode 100644 index 00000000000..499f393d7cc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/Type.java @@ -0,0 +1,39 @@ +package org.tensorflow.types; + +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.family.TType; + +public class Type { + + public Class typeClass() { + return typeClass; + } + + public DataType dataType() { + return dataType; + } + + public int byteSize() { + return byteSize; + } + + public boolean isVariableLength() { + return byteSize < 0; + } + + public TypeFactory factory() { + return factory; + } + + Type(Class typeClass, DataType dataType, int byteSize, TypeFactory factory) { + this.typeClass = typeClass; + this.dataType = dataType; + this.byteSize = byteSize; + this.factory = factory; + } + + private final Class typeClass; + private final DataType dataType; + private final int byteSize; + private final TypeFactory factory; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeFactory.java new file mode 100644 index 00000000000..4529657003e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeFactory.java @@ -0,0 +1,10 @@ +package org.tensorflow.types; + +import org.tensorflow.TensorHandle; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; + +public interface TypeFactory { + + T createDense(TensorHandle tensorHandle, Shape shape); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeRegistry.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeRegistry.java new file mode 100644 index 00000000000..99e9170fa55 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TypeRegistry.java @@ -0,0 +1,98 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.types; + +import java.util.HashMap; +import java.util.Map; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TType; + +/** + * Utility class for working with {@link DataType} objects. + */ +public final class TypeRegistry { + + /** + * Find a type registration from a data type + * + * @param dataType data type + * @return type registered to this code + * @throws IllegalArgumentException if the code matches no registered data type + */ + public static Type find(DataType dataType) { + Type entry = TYPES_BY_CODE.get(dataType.getNumber()); + if (entry == null) { + throw new IllegalArgumentException("No type has been registered for datatype " + dataType); + } + return (Type)entry; + } + + /** + * Find a type registration from a type class + * + * @param typeClass class implementing {@link Tensor} + * @return type registration + * @throws IllegalArgumentException if the code matches no registered data type + */ + public static Type find(Class typeClass) { + Type entry = TYPES_BY_CLASS.get(typeClass); + if (entry == null) { + throw new IllegalArgumentException("Class \"" + typeClass.getName() + "\" is not a valid datatype class"); + } + return (Type)entry; + } + + private static final Map> TYPES_BY_CODE = new HashMap<>(); + private static final Map, Type> TYPES_BY_CLASS = new HashMap<>(); + + private static void register(Class typeClass) { + TensorType typeAnnot = typeClass.getDeclaredAnnotation(TensorType.class); + if (typeAnnot == null) { + throw new IllegalArgumentException("Class \"" + typeClass.getName() + "\" must be annotated " + + "with @TensorType to be registered as a tensor type"); + } + TypeFactory factory; + try { + factory = typeAnnot.factory().newInstance(); + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException("Class \"" + typeClass.getName() + "\" must have a public " + + "parameter-less constructor to be used as a tensor type factory"); + } + Type type = new Type(typeClass, typeAnnot.dataType(), typeAnnot.byteSize(), factory); + TYPES_BY_CLASS.put(typeClass, type); + + // If more than one tensor type is mapped to a given native code, the last registered will + // have priority. This way, we can allow user to register their own classes to map tensors + // of a given data type. + TYPES_BY_CODE.put(type.dataType().getNumber(), type); + TYPES_BY_CODE.put(type.dataType().getNumber() + 100, type); + } + + static { + register(TBool.class); + register(TFloat64.class); + register(TFloat32.class); + register(TFloat16.class); + register(TInt32.class); + register(TInt64.class); + register(TString.class); + register(TUint8.class); + register(TBfloat16.class); + } +} \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java new file mode 100644 index 00000000000..b5da7ddcb35 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java @@ -0,0 +1,40 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.types.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.TypeFactory; + +/** Represents a type of elements in a {@link Tensor} */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface TensorType { + + DataType dataType(); + + int byteSize(); + + /** + * The class implementing this tensor type + */ + Class factory(); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TInteger.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TInteger.java new file mode 100644 index 00000000000..0112d278e49 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TInteger.java @@ -0,0 +1,5 @@ +package org.tensorflow.types.family; + +public interface TInteger extends TNumber { + +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 8f3451b9a68..8202d22f570 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -1,37 +1,31 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ======================================================================= - */ - package org.tensorflow.types.family; -/** - * Marker interface for all tensor types. - * - *

Tensor types are carried as a generic parameter of the {@link org.tensorflow.Tensor Tensor} - * class bound by the {@code TType} interface. This generic parameter ensure type-compatibility - * between operands of a computation at compile-time. For example: - * - *

{@code
- * Tensor tensor1 = TFloat32.ofShape(2, 3, 2);
- * Tensor tensor2 = TFloat32.ofShape(2, 3, 2);
- * Tensor tensor3 = TInt32.ofShape(2, 3, 2);
- *
- * Ops tf = Ops.create();
- * tf.math.add(tf.constant(tensor1), tf.constant(tensor2));  // OK
- * tf.math.add(tf.constant(tensor1), tf.constant(tensor3));  // Compilation failure
- * }
- */ -public interface TType {} +import org.tensorflow.Tensor; +import org.tensorflow.TensorHandle; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; + +public interface TType extends Tensor { + + Class type(); + + @Override + default void close() { + handle().close(); + } + + @Override + default ByteDataBuffer rawData() { + return handle().rawData(); + } + + @Override + default DataType dataType() { + return handle().dataType(); + } + + @Override + default long numBytes() { + return handle().numBytes(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/package-info.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/package-info.java index fcf1f554133..79e9e93fea3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/package-info.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/package-info.java @@ -21,8 +21,5 @@ *

Some operations enforces that only operands of a type from a given family can be passed * in argument. For example, if an operation only allows numeric operands, such operands must be * bound to the {@link org.tensorflow.types.family.TNumber TNumber} interface. - * - *

All tensor types is bound to {@link org.tensorflow.types.family.TType TType}, which lays at - * the root of the family hierarchy. */ package org.tensorflow.types.family; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java index afbd69fabe5..b41ce514f00 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java @@ -32,10 +32,10 @@ * *

Instances of tensor types must also implement the {@link org.tensorflow.ndarray.NdArray NdArray} * interface so a user can access directly the tensor data in a n-dimensional space by invoking - * {@link org.tensorflow.Tensor#data() Tensor.data()}. + * {@link org.tensorflow.util.Tensor#data() Tensor.data()}. * *

Note that while it is always possible to allocate a tensor using the - * {@link org.tensorflow.Tensor#of(org.tensorflow.DataType, Shape) Tensor.of(...)} + * {@link org.tensorflow.util.Tensor#of(org.tensorflow.DataType, Shape) Tensor.of(...)} * method, most tensor types expose factory methods that simplify the creation process, like * {@code scalarOf(...)}, {@code vectorOf(...)}, {@code tensorOf(...)}, etc. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorList.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorList.java new file mode 100644 index 00000000000..87a4fb378c9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorList.java @@ -0,0 +1,60 @@ +package org.tensorflow.util; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.stream.Stream; +import org.tensorflow.types.family.TType; + +public class TensorList implements AutoCloseable { + + public boolean isEmpty() { + return tensors.isEmpty(); + } + + public int size() { + return tensors.size(); + } + + public Iterator iterator() { + return tensors.iterator(); + } + + public ListIterator listIterator() { + return tensors.listIterator(); + } + + public Stream stream() { + return tensors.stream(); + } + + public boolean add(T tensor) { + return tensors.add(tensor); + } + + public T single() { + if (tensors.size() != 1) { + throw new IllegalStateException("List must contain a single tensor to use non-indexed getter"); + } + return get(0); + } + + public T get(int index) { + return (T)tensors.get(index); + } + + public TensorMap toMap(Collection tensorNames) { + return new TensorMap(tensorNames, this); + } + + @Override + public void close() { + for (TType t : tensors) { + t.handle().release(); + } + } + + private final List tensors = new ArrayList<>(); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorMap.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorMap.java new file mode 100644 index 00000000000..7e821ccf2af --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/util/TensorMap.java @@ -0,0 +1,72 @@ +package org.tensorflow.util; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import org.tensorflow.types.family.TType; + +public class TensorMap extends HashMap implements AutoCloseable { + + @Override + public TType put(String key, TType value) { + value.handle().retain(); + TType t = super.put(key, value); + if (t != null) { + t.handle().release(); + } + return t; + } + + @Override + public TType remove(Object key) { + TType t = super.remove(key); + if (t != null) { + t.handle().release(); + } + return t; + } + + @Override + public void putAll(Map m) { + super.putAll(m); + for (TType t : m.values()) { + t.handle().retain(); + } + } + + @Override + public void clear() { + for (TType t : values()) { + t.handle().release(); + } + super.clear(); + } + + @Override + public void close() { + clear(); + } + + public T get(String name) { + return (T) super.get(name); + } + + public T single() { + if (size() != 1) { + throw new IllegalStateException("List must contain a single tensor to use non-indexed getter"); + } + return (T) values().iterator().next(); + } + + public TensorMap() { + super(); + } + + public TensorMap(Collection names, TensorList tensors) { + Iterator tensorIter = tensors.iterator(); + for (String name : names) { + put(name, tensorIter.next()); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 3ea20fcbb46..dd17c891e69 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -29,14 +29,14 @@ public class ConcreteFunctionTest { private static Signature plusFive(Ops tf) { - Placeholder input = tf.placeholder(TFloat32.DTYPE); + Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } private static Signature minusTwo(Ops tf) { - Placeholder input = tf.placeholder(TFloat32.DTYPE); + Placeholder input = tf.placeholder(TFloat32.class); Sub output = tf.math.sub(input, tf.constant(2.0f)); return Signature.builder().key("minusTwo").input("x", input).output("y", output).build(); } @@ -44,8 +44,8 @@ private static Signature minusTwo(Ops tf) { @Test public void createFunction() { try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } @@ -54,8 +54,8 @@ public void createFunctionFromGraph() { try (Graph g = new Graph()) { Signature signature = plusFive(Ops.create(g)); try (ConcreteFunction f = ConcreteFunction.create(signature, g); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } } @@ -66,8 +66,8 @@ public void createFunctionFromSession() { Signature signature = plusFive(Ops.create(g)); try (Session s = new Session(g)) { try (ConcreteFunction f = ConcreteFunction.create(signature, s); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } } @@ -77,8 +77,9 @@ public void createFunctionFromSession() { public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + TFloat32 y = f1.call(x); + assertEquals(6.0f, ((TFloat32)f2.call(y)).getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index 6802ead9592..fda6084b796 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -45,7 +45,7 @@ public void failToBuildOpIfSessionIsClosed() { opBuilder = new EagerOperationBuilder(session, "Empty", "empty"); } try { - opBuilder.setAttr("dtype", TFloat32.DTYPE); + opBuilder.setAttr("dtype", TFloat32.class); fail(); } catch (IllegalStateException e) { // expected @@ -93,9 +93,9 @@ public void setAttrs() { try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); // dtype, tensor attributes. - try (Tensor t = TInt32.scalarOf(1)) { + try (TInt32 t = TInt32.scalarOf(1)) { opBuilder(session, "Const", "DataTypeAndTensor") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", TInt32.class) .setAttr("value", t) .build(); } @@ -103,7 +103,7 @@ public void setAttrs() { opBuilder(session, "RandomUniform", "DataTypeAndInt") .addInput(tf.array(1).asOutput()) .setAttr("seed", 10) - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", TFloat32.class) .build(); // list(int), string opBuilder(session, "MaxPool", "IntListAndString") @@ -124,7 +124,7 @@ public void setAttrs() { .build(); // list(shape) opBuilder(session, "FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {TInt32.DTYPE, TInt32.DTYPE}) + .setAttr("component_types", new Class[] {TInt32.class, TInt32.class}) .setAttr("shapes", new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}) .build(); // bool diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 09d2214cc6a..c702f3e9883 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -23,6 +23,7 @@ import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -46,13 +47,13 @@ public void failToCreateIfSessionIsClosed() { @Test public void outputDataTypeAndShape() { try (EagerSession session = EagerSession.create(); - Tensor t = TInt32.tensorOf(Shape.of(2, 3))) { + TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { EagerOperation op = opBuilder(session, "Const", "OutputAttrs") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", TInt32.class) .setAttr("value", t) .build(); - assertEquals(TInt32.DTYPE, op.dtype(0)); + assertEquals(DataType.DT_INT32, op.dtype(0)); assertEquals(2, op.shape(0).size(0)); assertEquals(3, op.shape(0).size(1)); } @@ -67,12 +68,12 @@ public void outputTensor() { .addInput(tf.constant(2).asOutput()) .addInput(tf.constant(4).asOutput()) .build(); - assertEquals(6, add.tensor(0).expect(TInt32.DTYPE).data().getInt()); + assertEquals(6, ((TInt32)Tensors.fromHandle(add.tensor(0))).getInt()); // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved assertEquals(0, add.shape(0).numDimensions()); - assertEquals(TInt32.DTYPE, add.dtype(0)); + assertEquals(DataType.DT_INT32, add.dtype(0)); } } @@ -123,7 +124,7 @@ public void numOutputs() { opBuilder(session, "UniqueWithCountsV2", "unq") .addInput(tf.constant(new int[]{1, 2, 1}).asOutput()) .addInput(tf.constant(new int[]{0}).asOutput()) - .setAttr("out_idx", TInt32.DTYPE) + .setAttr("out_idx", TInt32.class) .build(); assertEquals(3, op.numOutputs()); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 35bfa808238..4d05dd9ad26 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -50,12 +50,12 @@ public void failWhenMixingOperationsOnDifferentGraphs() { @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); - Tensor t = TInt32.scalarOf(1)) { + TInt32 t = TInt32.scalarOf(1)) { OperationBuilder b = - g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); + g.opBuilder("Const", "Const").setAttr("dtype", t.type()).setAttr("value", t); b.build(); try { - b.setAttr("dtype", t.dataType()); + b.setAttr("dtype", t.type()); } catch (IllegalStateException e) { // expected exception. } @@ -66,8 +66,8 @@ public void failOnUseAfterBuild() { public void failOnUseAfterGraphClose() { OperationBuilder b = null; try (Graph g = new Graph(); - Tensor t = TInt32.scalarOf(1)) { - b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); + TInt32 t = TInt32.scalarOf(1)) { + b = g.opBuilder("Const", "Const").setAttr("dtype", t.type()).setAttr("value", t); } try { b.build(); @@ -88,9 +88,9 @@ public void setAttr() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); // dtype, tensor attributes. - try (Tensor t = TInt32.scalarOf(1)) { + try (TInt32 t = TInt32.scalarOf(1)) { g.opBuilder("Const", "DataTypeAndTensor") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", TInt32.class) .setAttr("value", t) .build() .output(0); @@ -106,7 +106,7 @@ public void setAttr() { g.opBuilder("RandomUniform", "Int") .addInput(tf.array(1).asOutput()) .setAttr("seed", 10) - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", TFloat32.class) .build(); assertTrue(hasNode(g, "Int")); // list(int) @@ -132,23 +132,23 @@ public void setAttrShape() { try (Graph g = new Graph()) { Output n = g.opBuilder("Placeholder", "unknown") - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", TFloat32.class) .setAttr("shape", Shape.unknown()) .build() .output(0); assertEquals(-1, n.shape().numDimensions()); - assertEquals(TFloat32.DTYPE, n.dataType()); + assertEquals(TFloat32.class, n.type()); n = g.opBuilder("Placeholder", "batch_of_vectors") - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", TFloat32.class) .setAttr("shape", Shape.of(-1, 784)) .build() .output(0); assertEquals(2, n.shape().numDimensions()); assertEquals(-1, n.shape().size(0)); assertEquals(784, n.shape().size(1)); - assertEquals(TFloat32.DTYPE, n.dataType()); + assertEquals(TFloat32.class, n.type()); } } @@ -169,10 +169,10 @@ public void setAttrShapeList() { public void addControlInput() { try (Graph g = new Graph(); Session s = new Session(g); - Tensor yes = TBool.scalarOf(true); - Tensor no = TBool.scalarOf(false)) { + TBool yes = TBool.scalarOf(true); + TBool no = TBool.scalarOf(false)) { Ops tf = Ops.create(g); - Output placeholder = tf.placeholder(TBool.DTYPE).asOutput(); + Output placeholder = tf.placeholder(TBool.class).asOutput(); GraphOperation check = g.opBuilder("Assert", "assert") .addInput(placeholder) @@ -200,7 +200,7 @@ private static void testSetAttrShapeList(Shape[] shapes) { int[][] matrix = new int[][] {{0, 0}, {0, 0}}; Output queue = g.opBuilder("FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {TInt32.DTYPE, TInt32.DTYPE}) + .setAttr("component_types", new Class[] {TInt32.class, TInt32.class}) .setAttr("shapes", shapes) .build() .output(0); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index de376015e3f..64643c0c58c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -28,6 +28,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.linalg.MatMul; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.util.TensorList; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -42,7 +43,7 @@ public void graphDefRoundTrip() { Ops tf = Ops.create(g); tf.withName("Y").linalg.matMul( tf.withName("A").constant(new int[2][2]), - tf.withName("X").placeholder(TInt32.DTYPE), + tf.withName("X").placeholder(TInt32.class), MatMul.transposeA(true).transposeB(false) ); graphDef = g.toGraphDef(); @@ -140,8 +141,8 @@ public void addGradientsToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x1 = tf.placeholder(TFloat32.DTYPE).output(); - Output x2 = tf.placeholder(TFloat32.DTYPE).output(); + Output x1 = tf.placeholder(TFloat32.class).output(); + Output x2 = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); @@ -149,29 +150,28 @@ public void addGradientsToGraph() { Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); - assertEquals(TFloat32.DTYPE, grads0[0].dataType()); + assertEquals(TFloat32.class, grads0[0].type()); Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); assertNotNull(grads1); assertEquals(2, grads1.length); - assertEquals(TFloat32.DTYPE, grads1[0].dataType()); - assertEquals(TFloat32.DTYPE, grads1[1].dataType()); + assertEquals(TFloat32.class, grads1[0].type()); + assertEquals(TFloat32.class, grads1[1].type()); - try (Tensor c1 = TFloat32.scalarOf(3.0f); - Tensor c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList> outputs = new AutoCloseableList<>( - s.runner() + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); + TFloat32 c2 = TFloat32.scalarOf(2.0f); + TensorList outputs = s.runner() .feed(x1, c1) .feed(x2, c2) .fetch(grads0[0]) .fetch(grads1[0]) .fetch(grads1[1]) - .run())) { + .run()) { assertEquals(3, outputs.size()); - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(6.0f, outputs.get(1).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(1.0f, outputs.get(2).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32)outputs.get(2)).getFloat(), 0.0f); } } } @@ -182,23 +182,22 @@ public void addGradientSumsToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); Output[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); assertNotNull(grad); assertEquals(1, grad.length); - assertEquals(TFloat32.DTYPE, grad[0].dataType()); + assertEquals(TFloat32.class, grad[0].type()); - try (Tensor c = TFloat32.scalarOf(3.0f); - Tensor output = s.runner() + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TFloat32 output = s.runner() .feed(x, c) .fetch(grad[0]) .run() - .get(0) - .expect(TFloat32.DTYPE)) { - assertEquals(114.0f, output.data().getFloat(), 0.0f); + .single()) { + assertEquals(114.0f, output.getFloat(), 0.0f); } } } @@ -209,28 +208,27 @@ public void addGradientsWithInitialValuesToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); - assertEquals(TFloat32.DTYPE, grad0[0].dataType()); + assertEquals(TFloat32.class, grad0[0].type()); Output[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); assertNotNull(grad1); assertEquals(1, grad1.length); - assertEquals(TFloat32.DTYPE, grad1[0].dataType()); + assertEquals(TFloat32.class, grad1[0].type()); - try (Tensor c = TFloat32.scalarOf(3.0f); - Tensor output = s.runner() + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TFloat32 output = s.runner() .feed(x, c) .fetch(grad1[0]) .run() - .get(0) - .expect(TFloat32.DTYPE)) { - assertEquals(108.0f, output.data().getFloat(), 0.0f); + .single()) { + assertEquals(108.0f, output.getFloat(), 0.0f); } } } @@ -240,7 +238,7 @@ public void validateGradientsNames() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); @@ -269,7 +267,7 @@ public void buildWhileLoopSingleInput() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input = tf.placeholder(TInt32.DTYPE).output(); + Output input = tf.placeholder(TInt32.class).output(); @SuppressWarnings("unchecked") Output[] loopOutputs = g.whileLoop( @@ -284,14 +282,13 @@ public void buildWhileLoopSingleInput() { }, "test_loop"); - try (Tensor c = TInt32.scalarOf(2); - Tensor output = s.runner() + try (TInt32 c = TInt32.scalarOf(2); + TInt32 output = s.runner() .feed(input, c) .fetch(loopOutputs[0]) .run() - .get(0) - .expect(TInt32.DTYPE)) { - assertEquals(16, output.data().getInt()); // ((2^2)^2) + .single()) { + assertEquals(16, output.getInt()); // ((2^2)^2) } } } @@ -302,8 +299,8 @@ public void buildWhileLoopMultipleInputs() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input1 = tf.placeholder(TInt32.DTYPE).output(); - Output input2 = tf.placeholder(TInt32.DTYPE).output(); + Output input1 = tf.placeholder(TInt32.class).output(); + Output input2 = tf.placeholder(TInt32.class).output(); Output[] inputs = toArray(input1, input2); @SuppressWarnings("unchecked") @@ -320,19 +317,17 @@ public void buildWhileLoopMultipleInputs() { }, "test_loop"); - try (Tensor c1 = TInt32.scalarOf(2); - Tensor c2 = TInt32.scalarOf(5); - AutoCloseableList> outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { + try (TInt32 c1 = TInt32.scalarOf(2); + TInt32 c2 = TInt32.scalarOf(5); + TensorList outputs = s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run()) { assertEquals(2, outputs.size()); - assertEquals(16, outputs.get(0).expect(TInt32.DTYPE).data().getInt()); // ((2^2)^2) - assertEquals(625, outputs.get(1).expect(TInt32.DTYPE).data().getInt()); // ((5^2)^2) + assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) + assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index d807d13de00..67a0f9d105b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -44,6 +44,7 @@ import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.util.TensorMap; import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ @@ -107,9 +108,9 @@ public void exportFunctionWithVariables() throws IOException { f.session().run(Init.DEFAULT_NAME); // Call the graph and remember the result of computation for later - try (Tensor xTensor = TFloat32.tensorOf(xValue); - Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { - reducedSum = zTensor.data().getFloat(); + try (TFloat32 xTensor = TFloat32.tensorOf(xValue); + TFloat32 zTensor = f.call(xTensor)) { + reducedSum = zTensor.getFloat(); } // Save/export the model (which is a single function in this case) f.save(testFolder.toString()); @@ -153,15 +154,15 @@ public void exportFunctionWithVariables() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - try (Tensor xTensor = TFloat32.tensorOf(xValue)) { + try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { // Call the saved model function and make sure it returns the same result as before - try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + try (TFloat32 zTensor = function.call(xTensor)) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } // Now call the same function directly from the model - try (Tensor zTensor = - savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + try (TFloat32 zTensor = + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } } @@ -179,9 +180,9 @@ public void exportMultipleFunctions() throws IOException { ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { f1.session().run(Init.DEFAULT_NAME); - try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { - reducedSum = t.data().getFloat(); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = f1.call(x)) { + reducedSum = t.getFloat(); } SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) @@ -193,15 +194,15 @@ public void exportMultipleFunctions() throws IOException { assertEquals(2, model.signatures().size()); ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); - try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, t.data().getFloat(), EPSILON); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = f1.call(x)) { + assertEquals(reducedSum, t.getFloat(), EPSILON); } ConcreteFunction f2 = model.function("identity"); assertNotNull(f2); - try (Tensor x = TFloat32.scalarOf(10.0f); - Tensor t = f2.call(x).expect(TFloat32.DTYPE)) { - assertEquals(10.0f, t.data().getFloat(), 0.0f); + try (TFloat32 x = TFloat32.scalarOf(10.0f); + TFloat32 t = f2.call(x)) { + assertEquals(10.0f, t.getFloat(), 0.0f); } try { model.function("NoSuchFunction"); @@ -290,31 +291,29 @@ public void pythonTfFunction() { * Signature name used for saving 'add', argument names 'a' and 'b' */ ConcreteFunction add = bundle.function("add"); - Map> args = new HashMap(); - try (Tensor a = TFloat32.scalarOf(10.0f); - Tensor b = TFloat32.scalarOf(15.5f)) { - args.put("a", a); - args.put("b", b); - Map> result = add.call(args); + try (TensorMap args = new TensorMap()) { + args.put("a", TFloat32.scalarOf(10.0f)); + args.put("b", TFloat32.scalarOf(15.5f)); + TensorMap result = add.call(args); assertEquals(result.size(), 1); - try (Tensor c = result.values().iterator().next().expect(TFloat32.DTYPE)) { - assertEquals(25.5f, c.data().getFloat()); + try (TFloat32 c = result.single()) { + assertEquals(25.5f, c.getFloat()); } } } } private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { - Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape)); + Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); Variable y = tf - .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE)); + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); return Signature.builder().input("input", x).output("reducedSum", z).build(); } private static Signature buildIdentityGraph(Ops tf, String signatureKey) { - Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar())); Identity xprime = tf.identity(x); return Signature.builder().key(signatureKey).input("x", x).output("x", xprime).build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index fa41af32a29..cb179ac34ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -36,6 +35,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.util.TensorList; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -48,11 +48,10 @@ public void runUsingOperationNames() { Session s = new Session(g)) { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList> outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + TensorList outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.single()).getInt(0, 0)); } } } @@ -65,11 +64,10 @@ public void runUsingOperationHandles() { transpose_A_times_X(tf, new int[][] {{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList> outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + TensorList outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.single()).getInt(0, 0)); } } } @@ -83,22 +81,20 @@ public void runUsingColonSeparatedNames() { tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (Tensor fetched = - s.runner().fetch("Split:1").run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, fetched.data().getInt(0)); - assertEquals(4, fetched.data().getInt(1)); + try (TInt32 fetched = s.runner().fetch("Split:1").run().get(0)) { + assertEquals(3, fetched.getInt(0)); + assertEquals(4, fetched.getInt(1)); } // Feed using colon separated names. - try (Tensor fed = TInt32.vectorOf(4, 3, 2, 1); - Tensor fetched = + try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); + TInt32 fetched = s.runner() .feed("Split:0", fed) .feed("Split:1", fed) .fetch("Add") .run() - .get(0) - .expect(TInt32.DTYPE)) { - assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched.data()); + .single()) { + assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } } @@ -109,7 +105,7 @@ public void runWithMetadata() { Session s = new Session(g)) { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { Session.Run result = s.runner() .feed("X", x) @@ -117,9 +113,9 @@ public void runWithMetadata() { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList<>(result.outputs); + TensorList outputs = result.outputs; assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.single()).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.metadata); assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); @@ -135,11 +131,10 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList> outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + TensorList outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); - assertEquals(31415, outputs.get(0).expect(TInt32.DTYPE).data().getInt()); - assertEquals(2718, outputs.get(1).expect(TInt32.DTYPE).data().getInt()); + assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); outputs.close(); } } @@ -169,7 +164,7 @@ public void runInit() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable var1 = tf.variable(Shape.scalar(), TInt32.DTYPE); + Variable var1 = tf.variable(Shape.scalar(), TInt32.class); tf.initAdd(tf.assign(var1, tf.constant(10))); Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); @@ -177,8 +172,8 @@ public void runInit() { try (Session s = new Session(g)) { s.run(tf.init()); - try (Tensor t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(30, t.data().getInt()); + try (TInt32 t = s.runner().fetch(add).run().single()) { + assertEquals(30, t.getInt()); } } } @@ -189,7 +184,7 @@ public void runInitByName() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable var1 = tf.variable(Shape.scalar(), TInt32.DTYPE); + Variable var1 = tf.variable(Shape.scalar(), TInt32.class); tf.initAdd(tf.assign(var1, tf.constant(10))); Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); @@ -198,8 +193,8 @@ public void runInitByName() { try (Session s = new Session(g)) { s.run("init_test"); - try (Tensor t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(30, t.data().getInt()); + try (TInt32 t = s.runner().fetch(add).run().single()) { + assertEquals(30, t.getInt()); } try { s.run("wrong_name"); @@ -216,8 +211,8 @@ public void save() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); - Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -245,7 +240,7 @@ private static ConfigProto singleThreadConfigProto() { private static void transpose_A_times_X(Ops tf, int[][] a) { tf.withName("Y").linalg.matMul( tf.withName("A").constant(a), - tf.withName("X").placeholder(TInt32.DTYPE), + tf.withName("X").placeholder(TInt32.class), MatMul.transposeA(true).transposeB(false) ); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 01ef11efedd..601cf2e6de5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -30,9 +30,6 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import org.junit.jupiter.api.Test; -import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; @@ -40,7 +37,10 @@ import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -48,8 +48,10 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; +import org.tensorflow.types.Type; +import org.tensorflow.types.TypeRegistry; -/** Unit tests for {@link org.tensorflow.Tensor}. */ +/** Unit tests for {@link Tensor}. */ public class TensorTest { private static final double EPSILON = 1e-7; private static final float EPSILON_F = 1e-7f; @@ -64,22 +66,22 @@ public void createWithRawData() { String strings = "test"; Shape strings_shape = Shape.scalar(); byte[] strings_; // raw TF_STRING - try (Tensor t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { + try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { strings_ = new byte[(int)t.numBytes()]; t.rawData().read(strings_); } // validate creating a tensor using a raw data byte buffers { - try (Tensor t = Tensor.of(TBool.DTYPE, bools_shape, DataBuffers.of(bools_))) { + try (TBool t = Tensors.of(TBool.class, bools_shape, DataBuffers.of(bools_))) { boolean[] actual = new boolean[bools_.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(bools, actual); } // note: the buffer is expected to contain raw TF_STRING (as per C API) - try (Tensor t = Tensor.of(TString.DTYPE, strings_shape, DataBuffers.of(strings_))) { - assertEquals(strings, t.data().getObject()); + try (TString t = Tensors.of(TString.class, strings_shape, DataBuffers.of(strings_))) { + assertEquals(strings, t.getObject()); } } @@ -87,15 +89,15 @@ public void createWithRawData() { { DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) .asDoubleBuffer().put(doubles); - try (Tensor t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { + try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } // validate shape checking - try (Tensor t = Tensor.of(TBool.DTYPE, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { + try (TBool t = Tensors.of(TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { fail("should have failed on incompatible buffer"); } catch (IllegalArgumentException e) { // expected @@ -111,9 +113,9 @@ public void createFromBufferWithNativeByteOrder() { .asDoubleBuffer() .put(doubles); flipBuffer(buf); - try (Tensor t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } @@ -130,9 +132,9 @@ public void createFromBufferWithNonNativeByteOrder() { .asDoubleBuffer() .put(doubles); flipBuffer(buf); - try (Tensor t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } @@ -147,24 +149,24 @@ public void createWithTypedBuffer() { // validate creating a tensor using a typed buffer { Shape shape = Shape.of(4); - try (Tensor t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { + try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { DoubleBuffer actual = DoubleBuffer.allocate(doubles.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(doubles, actual); } - try (Tensor t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { + try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { FloatBuffer actual = FloatBuffer.allocate(floats.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(floats, actual); } - try (Tensor t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { + try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { IntBuffer actual = IntBuffer.allocate(ints.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(ints, actual); } - try (Tensor t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { + try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { LongBuffer actual = LongBuffer.allocate(longs.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(longs, actual); } } @@ -172,22 +174,22 @@ public void createWithTypedBuffer() { // validate shape-checking { Shape shape = Shape.of(5); - try (Tensor t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { + try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { + try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { + try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { + try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected @@ -203,11 +205,11 @@ public void readFromRawData() { long[] longs = {1L, 2L, 3L}; boolean[] bools = {true, false, true}; - try (Tensor tints = TInt32.vectorOf(ints); - Tensor tfloats = TFloat32.vectorOf(floats); - Tensor tdoubles = TFloat64.vectorOf(doubles); - Tensor tlongs = TInt64.vectorOf(longs); - Tensor tbools = TBool.vectorOf(bools)) { + try (TInt32 tints = TInt32.vectorOf(ints); + TFloat32 tfloats = TFloat32.vectorOf(floats); + TFloat64 tdoubles = TFloat64.vectorOf(doubles); + TInt64 tlongs = TInt64.vectorOf(longs); + TBool tbools = TBool.vectorOf(bools)) { // validate that any datatype is readable with ByteBuffer (content, position) { @@ -266,79 +268,79 @@ public void readFromRawData() { @Test public void scalars() { - try (Tensor t = TFloat32.scalarOf(2.718f)) { - assertEquals(TFloat32.DTYPE, t.dataType()); + try (TFloat32 t = TFloat32.scalarOf(2.718f)) { + assertEquals(TFloat32.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertEquals(2.718f, t.data().getFloat(), EPSILON_F); + assertEquals(2.718f, t.getFloat(), EPSILON_F); } - try (Tensor t = TFloat64.scalarOf(3.1415)) { - assertEquals(TFloat64.DTYPE, t.dataType()); + try (TFloat64 t = TFloat64.scalarOf(3.1415)) { + assertEquals(TFloat64.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertEquals(3.1415, t.data().getDouble(), EPSILON); + assertEquals(3.1415, t.getDouble(), EPSILON); } - try (Tensor t = TInt32.scalarOf(-33)) { - assertEquals(TInt32.DTYPE, t.dataType()); + try (TInt32 t = TInt32.scalarOf(-33)) { + assertEquals(TInt32.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertEquals(-33, t.data().getInt()); + assertEquals(-33, t.getInt()); } - try (Tensor t = TInt64.scalarOf(8589934592L)) { - assertEquals(TInt64.DTYPE, t.dataType()); + try (TInt64 t = TInt64.scalarOf(8589934592L)) { + assertEquals(TInt64.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertEquals(8589934592L, t.data().getLong()); + assertEquals(8589934592L, t.getLong()); } - try (Tensor t = TBool.scalarOf(true)) { - assertEquals(TBool.DTYPE, t.dataType()); + try (TBool t = TBool.scalarOf(true)) { + assertEquals(TBool.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertTrue(t.data().getBoolean()); + assertTrue(t.getBoolean()); } - try (Tensor t = TString.scalarOf("sombrero")) { - assertEquals(TString.DTYPE, t.dataType()); + try (TString t = TString.scalarOf("sombrero")) { + assertEquals(TString.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertEquals("sombrero", t.data().getObject()); + assertEquals("sombrero", t.getObject()); } final byte[] bytes = {1, 2, 3, 4}; - try (Tensor t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { - assertEquals(TString.DTYPE, t.dataType()); + try (TString t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { + assertEquals(TString.class, t.type()); assertEquals(0, t.shape().numDimensions()); - assertArrayEquals(bytes, t.data().asBytes().getObject()); + assertArrayEquals(bytes, t.asBytes().getObject()); } } @Test public void nDimensional() { DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); - try (Tensor t = TFloat64.tensorOf(vector)) { - assertEquals(TFloat64.DTYPE, t.dataType()); + try (TFloat64 t = TFloat64.tensorOf(vector)) { + assertEquals(TFloat64.class, t.type()); assertEquals(1, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); - assertEquals(vector, t.data()); + assertEquals(vector, t); } IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); - try (Tensor t = TInt32.tensorOf(matrix)) { - assertEquals(TInt32.DTYPE, t.dataType()); + try (TInt32 t = TInt32.tensorOf(matrix)) { + assertEquals(TInt32.class, t.type()); assertEquals(2, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t.data()); + assertEquals(matrix, t); } LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, }); - try (Tensor t = TInt64.tensorOf(threeD)) { - assertEquals(TInt64.DTYPE, t.dataType()); + try (TInt64 t = TInt64.tensorOf(threeD)) { + assertEquals(TInt64.class, t.type()); assertEquals(3, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(5, t.shape().size(1)); assertEquals(1, t.shape().size(2)); - assertEquals(threeD, t.data()); + assertEquals(threeD, t); } BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ @@ -346,14 +348,14 @@ public void nDimensional() { {{{false, false, true, true}, {false, true, false, false}}}, {{{false, true, false, true}, {false, true, true, false}}}, }); - try (Tensor t = TBool.tensorOf(fourD)) { - assertEquals(TBool.DTYPE, t.dataType()); + try (TBool t = TBool.tensorOf(fourD)) { + assertEquals(TBool.class, t.type()); assertEquals(4, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); assertEquals(1, t.shape().size(1)); assertEquals(2, t.shape().size(2)); assertEquals(4, t.shape().size(3)); - assertEquals(fourD, t.data()); + assertEquals(fourD, t); } } @@ -365,36 +367,36 @@ public void testNDimensionalStringTensor() { matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); } } - try (Tensor t = TString.tensorOf(matrix)) { - assertEquals(TString.DTYPE, t.dataType()); + try (TString t = TString.tensorOf(matrix)) { + assertEquals(TString.class, t.type()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t.data()); + assertEquals(matrix, t); } NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); - try (Tensor t = TString.tensorOfBytes(byteMatrix)) { - assertEquals(TString.DTYPE, t.dataType()); + try (TString t = TString.tensorOfBytes(byteMatrix)) { + assertEquals(TString.class, t.type()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(byteMatrix, t.data().asBytes()); - assertEquals(matrix, t.data()); + assertEquals(byteMatrix, t.asBytes()); + assertEquals(matrix, t); } } @Test public void testUint8TensorFromArray() { byte[] vector = new byte[] {1, 2, 3, 4}; - try (Tensor t = TUint8.vectorOf(vector)) { - assertEquals(TUint8.DTYPE, t.dataType()); + try (TUint8 t = TUint8.vectorOf(vector)) { + assertEquals(TUint8.class, t.type()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); byte[] got = new byte[4]; - t.data().read(DataBuffers.of(got)); + t.read(DataBuffers.of(got)); assertArrayEquals(vector, got); } } @@ -402,13 +404,13 @@ public void testUint8TensorFromArray() { @Test public void testCreateFromArrayOfBoxed() { Integer[] vector = new Integer[] {1, 2, 3, 4}; - try (Tensor t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { - assertEquals(TInt32.DTYPE, t.dataType()); + try (TInt32 t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { + assertEquals(TInt32.class, t.type()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); Integer[] got = new Integer[4]; - t.data().read(DataBuffers.ofObjects(got)); + t.read(DataBuffers.ofObjects(got)); assertArrayEquals(vector, got); } } @@ -421,7 +423,7 @@ public void failCreateOnMismatchedDimensions() { invalid[x][y] = new int[x + y + 1]; } } - try (Tensor t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { + try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); } catch (IllegalArgumentException e) { // The expected exception. @@ -433,11 +435,11 @@ public void tensorWithZeroDimension() { // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do // the same. - try (Tensor t = TInt32.tensorOf(Shape.of(3, 0, 1))) { + try (TInt32 t = TInt32.tensorOf(Shape.of(3, 0, 1))) { assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); } - try (Tensor t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { + try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); } @@ -445,53 +447,36 @@ public void tensorWithZeroDimension() { @Test public void allocateTensorWithSize() { - try (Tensor t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize())) { + Type type = TypeRegistry.find(TInt32.class); + try (TInt32 t = Tensors.of(TInt32.class, Shape.of(2, 2, 2), 8 * type.byteSize())) { // ok } - try (Tensor t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 9 * TInt32.DTYPE.byteSize())) { + try (TInt32 t = Tensors.of(TInt32.class, Shape.of(2, 2, 2), 9 * type.byteSize())) { // ok (size requested is larger that minimum space required) } try { - Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize() - 1); + Tensors.of(TInt32.class, Shape.of(2, 2, 2), 8 * type.byteSize() - 1); fail(); } catch (IllegalArgumentException e) { // as expected } } - @Test - public void useAfterClose() { - int n = 4; - Tensor t = TInt32.scalarOf(n); - t.close(); - try { - t.data(); - } catch (IllegalStateException e) { - // The expected exception. - } - } - @Test public void eagerTensorIsReleasedAfterSessionIsClosed() { - Tensor sum; + TInt32 sum; try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); sum = tf.math.add(tf.constant(10), tf.constant(20)).asTensor(); - sum.nativeHandle(); // does not throw - assertEquals(30, sum.data().getInt()); + sum.handle().get(); // does not throw + assertEquals(30, sum.getInt()); } try { - sum.nativeHandle(); + sum.handle().get(); fail("Tensor native handle should have been closed by ending eager session"); } catch (IllegalStateException e) { // as expected } - try { - sum.data().getInt(); - fail("Tensor data should not be accessible after tensor is closed"); - } catch (IllegalStateException e) { - // as expected - } } @Test @@ -503,12 +488,12 @@ public void fromHandle() { // An exception is made for this test, where the pitfalls of this is avoided by not calling // close() on both Tensors. final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); - try (Tensor src = TFloat32.tensorOf(matrix)) { - Tensor cpy = Tensor.fromHandle(src.nativeHandle()).expect(TFloat32.DTYPE); - assertEquals(src.dataType(), cpy.dataType()); + try (TFloat32 src = TFloat32.tensorOf(matrix)) { + TFloat32 cpy = Tensors.fromHandle(src.handle()); + assertEquals(src.type(), cpy.type()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); assertEquals(src.shape(), cpy.shape()); - assertEquals(matrix, cpy.data()); + assertEquals(matrix, cpy); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index bbebfd5f454..e83c956adb7 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -23,9 +23,8 @@ import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TType; +import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.op.Scope}. */ public class ScopeTest { @@ -169,11 +168,12 @@ public void composite() { // assertNotNull(g.operation("variance/zero")); // Verify correct results as well. - Tensor result = - sess.runner().fetch(var1.output()).run().get(0).expect(TInt32.DTYPE); - assertEquals(21704, result.data().getInt()); - result = sess.runner().fetch(var2.output()).run().get(0).expect(TInt32.DTYPE); - assertEquals(21704, result.data().getInt()); + try (TInt32 result = sess.runner().fetch(var1.output()).run().single()) { + assertEquals(21704, result.getInt()); + } + try (TInt32 result = sess.runner().fetch(var2.output()).run().single()) { + assertEquals(21704, result.getInt()); + } } } @@ -189,11 +189,11 @@ static Const create(Scope s, int[] v) { return create(s, TInt32.vectorOf(v)); } - static Const create(Scope s, Tensor value) { + static Const create(Scope s, T value) { return new Const<>( s.env() .opBuilder("Const", s.makeOpName("Const")) - .setAttr("dtype", value.dataType()) + .setAttr("dtype", value.type()) .setAttr("value", value) .build() .output(0)); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 266a62bd1ed..0f94a546ca8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -18,16 +18,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.util.TensorList; import org.tensorflow.op.Scope; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; @@ -62,10 +56,9 @@ public void createInts() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TInt32.DTYPE).data()); - assertEquals(array, t.get(1).expect(TInt32.DTYPE).data()); + try (TensorList t = sess.runner().fetch(op1).fetch(op2).run()) { + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -81,10 +74,9 @@ public void createFloats() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TFloat32.DTYPE).data()); - assertEquals(array, t.get(1).expect(TFloat32.DTYPE).data()); + try (TensorList t = sess.runner().fetch(op1).fetch(op2).run()) { + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -100,10 +92,9 @@ public void createDoubles() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TFloat64.DTYPE).data()); - assertEquals(array, t.get(1).expect(TFloat64.DTYPE).data()); + try (TensorList t = sess.runner().fetch(op1).fetch(op2).run()) { + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -119,10 +110,9 @@ public void createLongs() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TInt64.DTYPE).data()); - assertEquals(array, t.get(1).expect(TInt64.DTYPE).data()); + try (TensorList t = sess.runner().fetch(op1).fetch(op2).run()) { + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -138,10 +128,9 @@ public void createStrings() throws IOException { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TString.DTYPE).data()); - assertEquals(array, t.get(1).expect(TString.DTYPE).data()); + try (TensorList t = sess.runner().fetch(op1).fetch(op2).run()) { + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index a337bd73098..e15583cda8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -22,7 +22,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; @@ -36,8 +35,8 @@ public void tensorInputTensorOutput() { Session sess = new Session(g)) { Ops ops = Ops.create(g); Operand x = ops.math.add(ops.constant(1), ops.constant(2)); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, result.data().getInt()); + try (TInt32 result = sess.runner().fetch(x).run().single()) { + assertEquals(3, result.getInt()); } } } @@ -52,8 +51,8 @@ public void testListInputTensorOutput() { inputs.add(ops.constant(2)); inputs.add(ops.constant(3)); Operand x = ops.math.addN(inputs); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(6, result.data().getInt()); + try (TInt32 result = sess.runner().fetch(x).run().single()) { + assertEquals(6, result.getInt()); } } } @@ -70,15 +69,15 @@ public void testControlDependencies() { try (Graph g = new Graph(); Session sess = new Session(g)) { Ops ops = Ops.create(g); - Operand variable = ops.variable(Shape.scalar(), TInt32.DTYPE); + Operand variable = ops.variable(Shape.scalar(), TInt32.class); Operand initVariable = ops.assign(variable, ops.constant(0)); ArrayList controls = new ArrayList<>(); controls.add(ops.assign(variable, ops.constant(3))); Operand x = ops.withControlDependencies(controls).math.add(variable, ops.constant(0)); sess.runner().addTarget(initVariable).run(); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, result.data().getInt()); + try (TInt32 result = sess.runner().fetch(x).run().single()) { + assertEquals(3, result.getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index fe1503d415f..df2488662fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -21,12 +21,11 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; +import org.tensorflow.util.TensorList; import org.tensorflow.types.TFloat32; public class GradientsTest { @@ -37,7 +36,7 @@ public void createGradients() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -47,13 +46,10 @@ public void createGradients() { assertNotNull(grads.dy()); assertEquals(2, grads.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { - - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(18.0f, outputs.get(1).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TensorList outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); } } } @@ -64,7 +60,7 @@ public void createGradientsWithSum() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -74,11 +70,9 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { - - assertEquals(114.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TensorList outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { + assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } } } @@ -89,7 +83,7 @@ public void createGradientsWithInitialValues() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -100,12 +94,9 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { - - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TensorList outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } } } @@ -115,7 +106,7 @@ public void validateGradientsNames() { try (Graph g = new Graph()) { Ops tf = Ops.create(g).withSubScope("sub"); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y = tf.math.square(x).y(); Gradients grad0 = Gradients.create(tf.scope(), y, Arrays.asList(x)); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java index d5eb7412ea3..665d599145e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java @@ -20,12 +20,10 @@ import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestTemplate; import org.tensorflow.Graph; import org.tensorflow.EagerSession; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -40,22 +38,19 @@ public void testFlatten_Operand() { Session session = new Session(g)) { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Shape expResult = Shape.create(scope, operand, TInt64.DTYPE); + Shape expResult = Shape.create(scope, operand, TInt64.class); Operand reshaped = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Operand actual = Shapes.flatten(scope, reshaped); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (Tensor result1 = - session.runner().fetch(tfshape.asOutput()).run().get(0).expect(TInt64.DTYPE); - Tensor result2 = - session.runner().fetch(expResult.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result1 = session.runner().fetch(tfshape.asOutput()).run().single(); + TInt64 result2 = session.runner().fetch(expResult.asOutput()).run().single()) { result1 - .data() .scalars() .forEach( - s -> assertEquals(result2.data().getLong(index.getAndIncrement()), s.getLong())); + s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); } } } @@ -66,21 +61,21 @@ public void testFlatten_Shape() { try (EagerSession session = EagerSession.create()) { Scope scope = new Scope(session); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Shape expShape = Shape.create(scope, operand, TInt64.DTYPE); + Shape expShape = Shape.create(scope, operand, TInt64.class); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); - Operand flattened = Shapes.flatten(scope, tfshape, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); + Operand flattened = Shapes.flatten(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); flattened .asOutput() - .data() + .asTensor() .scalars() .forEach( s -> assertEquals( - expShape.asOutput().data().getLong(index.getAndIncrement()), s.getLong())); + expShape.asTensor().getLong(index.getAndIncrement()), s.getLong())); } } @@ -93,13 +88,12 @@ public void testSize_Shape() { Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); - Operand size = Shapes.size(scope, tfshape, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); + Operand size = Shapes.size(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (Tensor result1 = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt64.DTYPE)) { - result1.data().scalars().forEach(s -> assertEquals(8, s.getLong())); + try (TInt64 result1 = session.runner().fetch(size.asOutput()).run().single()) { + result1.scalars().forEach(s -> assertEquals(8, s.getLong())); } } } @@ -116,21 +110,21 @@ public void testSize_Shape_Operand() { Shape tfshape = Shape.create(scope, actual); Operand size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 0)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(4, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(4, s.getInt())); } size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 1)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(2, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(2, s.getInt())); } size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 2)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(1, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } } @@ -146,21 +140,21 @@ public void testSize_Operand_Operand() { Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Operand size = Shapes.size(scope, actual, Constant.scalarOf(scope, 0)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(4, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(4, s.getInt())); } size = Shapes.size(scope, actual, Constant.scalarOf(scope, 1)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(2, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(2, s.getInt())); } size = Shapes.size(scope, actual, Constant.scalarOf(scope, 2)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(1, s.getInt())); + try (TInt32 result = + session.runner().fetch(size.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } } @@ -177,9 +171,9 @@ public void testNumDimensions() { Shape tfshape = Shape.create(scope, actual); Operand nDims = Shapes.numDimensions(scope, tfshape); - try (Tensor result = - session.runner().fetch(nDims.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(3, s.getInt())); + try (TInt32 result = + session.runner().fetch(nDims.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(3, s.getInt())); } } } @@ -199,7 +193,7 @@ public void testReduceDims_Operand_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected = {8}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -224,7 +218,7 @@ public void testReduceDims_Shape_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected1 = {8}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -237,7 +231,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected2 = {2, 4}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -250,7 +244,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected3 = {2, 2, 2}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -274,10 +268,10 @@ public void testSqueeze() { Operand squeezed = Shapes.squeeze(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2}; - try (Tensor result = - session.runner().fetch(squeezed.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(squeezed.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -301,10 +295,10 @@ public void testHead() { Operand head = Shapes.head(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4}; - try (Tensor result = - session.runner().fetch(head.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(head.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -328,10 +322,10 @@ public void testTake() { Operand take = Shapes.take(scope, tfshape, Constant.scalarOf(scope, 2)); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 1}; - try (Tensor result = - session.runner().fetch(take.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(take.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -355,10 +349,10 @@ public void testTail() { Operand tail = Shapes.tail(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {1}; - try (Tensor result = - session.runner().fetch(tail.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(tail.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -382,10 +376,10 @@ public void testTakeLast() { Operand takeLast = Shapes.takeLast(scope, tfshape, Constant.scalarOf(scope, 3)); AtomicInteger index = new AtomicInteger(); int[] expected = {1, 2, 1}; - try (Tensor result = - session.runner().fetch(takeLast.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(takeLast.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -408,10 +402,10 @@ public void testPrependInt() { Operand prepend = Shapes.prepend(scope, tfshape, 3); AtomicInteger index = new AtomicInteger(); int[] expected = {3, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(prepend.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -429,15 +423,15 @@ public void testPrependLong() { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape, 1L); AtomicInteger index = new AtomicInteger(); long[] expected = {1, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + session.runner().fetch(prepend.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -465,10 +459,10 @@ public void testPrependShapeTInt32() { Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {2, 4, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(prepend.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -490,16 +484,16 @@ public void testPrependShapeTInt64() { Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); - Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); - Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); + Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); + Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {2, 4, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + session.runner().fetch(prepend.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -517,15 +511,15 @@ public void testAppendLong() { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand append = Shapes.append(scope, tfshape, 2L); AtomicInteger index = new AtomicInteger(); long[] expected = {4L, 2L, 2L}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + session.runner().fetch(append.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -548,10 +542,10 @@ public void testAppendInt() { Operand append = Shapes.append(scope, tfshape, 2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(append.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -579,10 +573,10 @@ public void testAppendShapeTInt32() { Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2, 4}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + session.runner().fetch(append.asOutput()).run().single()) { result - .data() + .scalars() .forEach( s -> { @@ -604,16 +598,15 @@ public void testAppendShapeTInt64() { Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); - Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); - Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); + Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); + Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {4, 2, 2, 4}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + session.runner().fetch(append.asOutput()).run().single()) { result - .data() .scalars() .forEach( s -> { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 9600f8b38fc..f84c7ce5530 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -19,12 +19,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Scope; +import org.tensorflow.util.TensorList; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -41,9 +40,9 @@ public void createIntZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.DTYPE); - try (Tensor result = sess.runner().fetch(op).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0, s.getInt())); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.class); + try (TInt32 result = sess.runner().fetch(op).run().single()) { + result.scalars().forEach(s -> assertEquals(0, s.getInt())); } } } @@ -54,9 +53,9 @@ public void createFloatZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); + try (TFloat32 result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); } } } @@ -67,9 +66,9 @@ public void createDoubleZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat64.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.class); + try (TFloat64 result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); } } } @@ -80,9 +79,9 @@ public void createLongZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TInt64.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0L, s.getLong())); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.class); + try (TInt64 result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(0L, s.getLong())); } } } @@ -93,9 +92,9 @@ public void createBooleanZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TBool.DTYPE)) { - result.data().scalars().forEach(s -> assertFalse(s.getBoolean())); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.class); + try (TBool result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertFalse(s.getBoolean())); } } } @@ -106,9 +105,9 @@ public void createUint8Zeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TUint8.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0, s.getByte())); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.class); + try (TUint8 result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertEquals(0, s.getByte())); } } } @@ -119,9 +118,9 @@ public void createStringZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TString.DTYPE)) { - result.data().scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.class); + try (TString result = sess.runner().fetch(op.asOutput()).run().single()) { + result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); } } } @@ -132,8 +131,10 @@ public void operationsComposingZerosAreCorrectlyNamed() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE); - List> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); + try (TensorList results = + sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run()) { + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java index 87b24b0da2a..398b1693e14 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java @@ -21,90 +21,93 @@ import org.junit.jupiter.api.Test; import org.tensorflow.EagerSession; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.math.Add; import org.tensorflow.op.math.Sub; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.types.family.TNumber; -abstract class NumericTypesTestBase, U> { +abstract class NumericTypesTestBase { @Test public void initializeTensorsWithZeros() { // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) - Tensor tensor = allocateTensor(Shape.of(2, 3, 2)); - NdArray tensorData = tensor.data(); + T tensor = allocateTensor(Shape.of(2, 3, 2)); - assertEquals(3, tensorData.rank()); - assertEquals(12, tensorData.size()); + assertEquals(3, tensor.rank()); + assertEquals(12, tensor.size()); + NdArray data = (NdArray)tensor; try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); // Initialize tensor memory with zeros and take a snapshot - tensorData.scalars().forEach(scalar -> scalar.setObject(valueOf(0))); + data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(0))); Constant x = tf.constant(tensor); // Initialize the same tensor memory with ones and take a snapshot - tensorData.scalars().forEach(scalar -> scalar.setObject(valueOf(1))); + data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(1))); Constant y = tf.constant(tensor); // Subtract y from x and validate the result Sub sub = tf.math.sub(x, y); - sub.data().scalars().forEach(scalar -> + ((NdArray)sub.asTensor()).scalars().forEach(scalar -> assertEquals(valueOf(-1), scalar.getObject()) ); } } @Test - public void genericTest() { - IntNdArray heapData = NdArrays.vectorOf(0, 1, 2, 3); + public void setAndCompute() { + NdArray heapData = allocateNdArray(Shape.of(4)) + .setObject(valueOf(0), 0) + .setObject(valueOf(1), 1) + .setObject(valueOf(2), 2) + .setObject(valueOf(3), 3); // Creates a 2x2 matrix - try (Tensor tensor = TInt32.tensorOf(Shape.of(2, 2))) { - IntNdArray tensorData = tensor.data(); + try (T tensor = allocateTensor(Shape.of(2, 2))) { + NdArray data = (NdArray)tensor; // Copy first 2 values of the vector to the first row of the matrix - tensorData.set(heapData.slice(Indices.range(0, 2)), 0); + data.set(heapData.slice(Indices.range(0, 2)), 0); // Copy values at an odd position in the vector as the second row of the matrix - tensorData.set(heapData.slice(Indices.odd()), 1); + data.set(heapData.slice(Indices.odd()), 1); - assertEquals(0, tensorData.getInt(0, 0)); - assertEquals(1, tensorData.getInt(0, 1)); - assertEquals(1, tensorData.getInt(1, 0)); - assertEquals(3, tensorData.getInt(1, 1)); + assertEquals(valueOf(0), data.getObject(0, 0)); + assertEquals(valueOf(1), data.getObject(0, 1)); + assertEquals(valueOf(1), data.getObject(1, 0)); + assertEquals(valueOf(3), data.getObject(1, 1)); // Read rows of the tensor in reverse order - IntNdArray reversedTensorData = tensorData.slice(Indices.all(), Indices.flip()); + NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); - assertEquals(1, reversedTensorData.getInt(0, 0)); - assertEquals(0, reversedTensorData.getInt(0, 1)); - assertEquals(3, reversedTensorData.getInt(1, 0)); - assertEquals(1, reversedTensorData.getInt(1, 1)); + assertEquals(valueOf(3), flippedData.getObject(0, 0)); + assertEquals(valueOf(1), flippedData.getObject(0, 1)); + assertEquals(valueOf(1), flippedData.getObject(1, 0)); + assertEquals(valueOf(0), flippedData.getObject(1, 1)); try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); - // Compute the power of the tensor by itself - Constant x = tf.constant(tensor); - IntNdArray result = tf.math.pow(x, x).data(); + Add add = tf.math.add(tf.constant(tensor), tf.constant(tensor)); + NdArray result = (NdArray)add.asTensor(); - // Validate result by computing the same operation in Java - tensorData.scalars().forEachIndexed((coords, s) -> - assertEquals(Math.pow(s.getInt(), s.getInt()), result.getInt(coords), 1e-7f) - ); + assertEquals(valueOf(0), result.getObject(0, 0)); + assertEquals(valueOf(2), result.getObject(0, 1)); + assertEquals(valueOf(2), result.getObject(1, 0)); + assertEquals(valueOf(6), result.getObject(1, 1)); } } } - abstract Tensor allocateTensor(Shape shape); + abstract T allocateTensor(Shape shape); + + abstract NdArray allocateNdArray(Shape shape); abstract U valueOf(Integer value); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java index 8681e805e3d..17a6e0dd2b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TBfloat16Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TBfloat16 allocateTensor(Shape shape) { return TBfloat16.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java index b72fe6fc01c..c1ae8ad3b6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat16Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat16 allocateTensor(Shape shape) { return TFloat16.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java index c4b1f6023f3..8df96f2871a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat32Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat32 allocateTensor(Shape shape) { return TFloat32.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java index 0e9c8947d0f..47b4b6d936a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat64Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat64 allocateTensor(Shape shape) { return TFloat64.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofDoubles(shape); + } + @Override Double valueOf(Integer value) { return value.doubleValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java index c52394bf210..9ea7f952f04 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java @@ -17,16 +17,24 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TInt32Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TInt32 allocateTensor(Shape shape) { return TInt32.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofInts(shape); + } + @Override Integer valueOf(Integer value) { return value; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java index 261ac546fc5..a88f3fb4d6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TInt64Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TInt64 allocateTensor(Shape shape) { return TInt64.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofLongs(shape); + } + @Override Long valueOf(Integer value) { return value.longValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java index a4700aa652f..08048947648 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Test; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; @@ -32,37 +31,28 @@ public class TStringTest { @Test public void createScalar() { - Tensor tensor = TString.scalarOf("Pretty vacant"); + TString tensor = TString.scalarOf("Pretty vacant"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.scalar(), data.shape()); - assertEquals("Pretty vacant", data.getObject()); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Pretty vacant", tensor.getObject()); } @Test public void createrScalarLongerThan127() { - Tensor tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); + TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.scalar(), data.shape()); - assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", data.getObject()); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); } @Test public void createVector() { - Tensor tensor = TString.vectorOf("Pretty", "vacant"); + TString tensor = TString.vectorOf("Pretty", "vacant"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.of(2), data.shape()); - assertEquals("Pretty", data.getObject(0)); - assertEquals("vacant", data.getObject(1)); + assertEquals(Shape.of(2), tensor.shape()); + assertEquals("Pretty", tensor.getObject(0)); + assertEquals("vacant", tensor.getObject(1)); } @Test @@ -73,30 +63,27 @@ public void createCopy() { .setObject("New", 1, 0) .setObject("York", 1, 1); - Tensor tensor = TString.tensorOf(strings); + TString tensor = TString.tensorOf(strings); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); strings.scalars().forEachIndexed((idx, s) -> - assertEquals(s.getObject(), data.getObject(idx)) + assertEquals(s.getObject(), tensor.getObject(idx)) ); } @Test public void defaultCharsetIsUtf8() { - Tensor tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.data().asBytes().getObject(); + TString tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); assertArrayEquals(new byte[] { (byte)0xF0, (byte)0x9F, (byte)0x90, (byte)0xA5 }, bytes); - assertEquals(BABY_CHICK, tensor.data().getObject()); + assertEquals(BABY_CHICK, tensor.getObject()); } @Test public void usingDifferentCharset() { - Tensor tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.data().asBytes().getObject(); + TString tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); assertArrayEquals(new byte[] { (byte)0x3D, (byte)0xD8, (byte)0x25, (byte)0xDC }, bytes); - assertEquals(BABY_CHICK, tensor.data().using(StandardCharsets.UTF_16LE).getObject()); + assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); } @Test @@ -106,11 +93,11 @@ public void initializingTensorWithRawBytes() { for (int i = 0; i < strings.length; ++i) { bytes.setObject(strings[i].getBytes(), i); } - Tensor tensor = TString.tensorOfBytes(bytes); + TString tensor = TString.tensorOfBytes(bytes); assertNotNull(tensor); assertEquals(bytes.shape(), tensor.shape()); - NdArray tensorBytes = tensor.data().asBytes(); + NdArray tensorBytes = tensor.asBytes(); for (int i = 0; i < strings.length; ++i) { assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java index cc83087e018..ce7397d5878 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TUint8Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TUint8 allocateTensor(Shape shape) { return TUint8.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofBytes(shape); + } + @Override Byte valueOf(Integer value) { return value.byteValue(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index ae3d7e8c896..d421bc1155d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; @@ -89,7 +88,7 @@ public Operand call(Operand input) { Operand result = tf.nn.elu(input); if (alpha == 1.0) return result; else { - DataType dataType = input.asOutput().dataType(); + Class dataType = input.asOutput().type(); Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType)); Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType)); return tf.select(cond, result, y); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index a486cbdc601..9e254a8fa34 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; @@ -63,7 +62,7 @@ public HardSigmoid(Ops tf) { */ @Override public Operand call(Operand input) { - DataType dataType = input.asOutput().dataType(); + Class dataType = input.asOutput().type(); Operand point2 = tf.dtypes.cast(tf.constant(0.2), dataType); Operand point5 = tf.dtypes.cast(tf.constant(0.5), dataType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index c24cf71077d..d405aa1bc7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.math.Greater; @@ -99,7 +98,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { @Override public Operand call(Operand input) { - DataType dataType = input.asOutput().dataType(); + Class dataType = input.asOutput().type(); boolean clipMax = !Float.isNaN(maxValue); Operand negativePart = null; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 007bcb01a40..c1eca6aeb29 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.impl.BatchDataset; import org.tensorflow.framework.data.impl.MapDataset; @@ -33,6 +32,7 @@ import java.util.Iterator; import java.util.List; import java.util.function.Function; +import org.tensorflow.types.family.TType; /** * Represents a potentially large list of independent elements (samples), and allows iteration and @@ -41,11 +41,11 @@ public abstract class Dataset implements Iterable>> { protected Ops tf; private Operand variant; - private List> outputTypes; + private List> outputTypes; private List outputShapes; public Dataset( - Ops tf, Operand variant, List> outputTypes, List outputShapes) { + Ops tf, Operand variant, List> outputTypes, List outputShapes) { if (tf == null) { throw new IllegalArgumentException("Ops accessor cannot be null."); } @@ -266,7 +266,7 @@ public DatasetIterator makeOneShotIterator() { * @return A new `Dataset` */ public static Dataset fromTensorSlices( - Ops tf, List> tensors, List> outputTypes) { + Ops tf, List> tensors, List> outputTypes) { return new TensorSliceDataset(tf, tensors, outputTypes); } @@ -288,7 +288,7 @@ public Operand getVariant() { } /** Get a list of output types for each component of this dataset. */ - public List> getOutputTypes() { + public List> getOutputTypes() { return this.outputTypes; } @@ -305,7 +305,7 @@ public Ops getOpsInstance() { public String toString() { return "Dataset{" + "outputTypes=" - + Arrays.toString(getOutputTypes().stream().map(DataType::name).toArray()) + + Arrays.toString(getOutputTypes().stream().map(Class::getSimpleName).toArray()) + ", outputShapes=" + Arrays.toString(getOutputShapes().stream().map(Shape::toString).toArray()) + "}"; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index f4c4b681715..5dbb2bfacff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.op.Op; @@ -25,6 +24,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import org.tensorflow.types.family.TType; /** * Represents the state of an iteration through a tf.data Datset. DatasetIterator is not a @@ -106,7 +106,7 @@ public class DatasetIterator implements Iterable>> { private Operand iteratorResource; private Op initializer; - protected List> outputTypes; + protected List> outputTypes; protected List outputShapes; /** @@ -124,7 +124,7 @@ public DatasetIterator( Ops tf, Operand iteratorResource, Op initializer, - List> outputTypes, + List> outputTypes, List outputShapes) { this.tf = tf; @@ -137,7 +137,7 @@ public DatasetIterator( public DatasetIterator( Ops tf, Operand iteratorResource, - List> outputTypes, + List> outputTypes, List outputShapes) { this.tf = tf; this.iteratorResource = iteratorResource; @@ -236,7 +236,7 @@ public Op makeInitializer(Dataset dataset) { * @return A new DatasetIterator */ public static DatasetIterator fromStructure( - Ops tf, List> outputTypes, List outputShapes) { + Ops tf, List> outputTypes, List outputShapes) { Operand iteratorResource = tf.scope().env() instanceof Graph ? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes) @@ -272,7 +272,7 @@ public Iterator>> iterator() { @Override public boolean hasNext() { - return nextOptional.hasValue().data().getBoolean(); + return nextOptional.hasValue().asTensor().getBoolean(); } @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java index 925252c7298..6617c33eaf7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; @@ -23,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import org.tensorflow.types.family.TType; /** * An optional represents the result of a dataset getNext operation that may fail, when the end of @@ -36,11 +36,11 @@ public Operand getOptionalVariant() { } private Operand optionalVariant; - private List> outputTypes; + private List> outputTypes; private List outputShapes; public DatasetOptional( - Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { + Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { this.tf = tf; this.optionalVariant = optionalVariant; this.outputTypes = outputTypes; @@ -75,7 +75,7 @@ public List> getValue() { public static DatasetOptional fromComponents( Ops tf, List> components, - List> outputTypes, + List> outputTypes, List outputShapes) { Operand optionalVariant = tf.data.optionalFromValue(components); return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java index 277b049cf6f..f0561b2e61e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -25,6 +24,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class BatchDataset extends Dataset { public BatchDataset( @@ -32,7 +32,7 @@ public BatchDataset( Operand variant, Constant batchSize, Constant dropRemainder, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java index 6731bac60b3..63b4208480b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -24,6 +23,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class SkipDataset extends Dataset { @@ -31,7 +31,7 @@ public SkipDataset( Ops tf, Operand variant, Constant count, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java index ed721b13ebf..00297152e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java @@ -34,7 +34,7 @@ public TFRecordDataset( super( tf, tf.data.tfRecordDataset(filenames, compressionType, bufferSize), - Collections.singletonList(TString.DTYPE), + Collections.singletonList(TString.class), Collections.singletonList(Shape.scalar())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java index 08c57d44a73..39ca9759e74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -24,6 +23,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class TakeDataset extends Dataset { @@ -31,7 +31,7 @@ public TakeDataset( Ops tf, Operand variant, Constant count, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java index 14405ebdaf5..eac0f1cc1ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -23,10 +22,11 @@ import java.util.List; import java.util.stream.Collectors; +import org.tensorflow.types.family.TType; public class TensorSliceDataset extends Dataset { - public TensorSliceDataset(Ops tf, List> components, List> outputTypes) { + public TensorSliceDataset(Ops tf, List> components, List> outputTypes) { super(tf, makeVariant(tf, components, outputTypes), outputTypes, outputShapes(components)); } @@ -35,7 +35,7 @@ private static List outputShapes(List> components) { } private static Operand makeVariant( - Ops tf, List> components, List> outputTypes) { + Ops tf, List> components, List> outputTypes) { if (!(components.size() == outputTypes.size())) { throw new IllegalArgumentException( "Lists `tensors` and `dtypes` must have the same number of elements."); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java index 4ef47825211..c9a26304778 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java @@ -34,7 +34,7 @@ public TextLineDataset( super( tf, tf.data.textLineDataset(filenames, compressionType, bufferSize), - Collections.singletonList(TString.DTYPE), + Collections.singletonList(TString.class), Collections.singletonList(Shape.scalar())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index b4544de9bd0..16aed3db6ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -85,17 +86,17 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!(dtype.isNumeric() || dtype.isBoolean())) { - throw new IllegalArgumentException("DataType must be numeric or boolean: " + dtype.name()); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { + throw new IllegalArgumentException("DataType must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 3d5d37b91d3..fd6b8542003 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -16,6 +16,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -70,7 +71,7 @@ * @see VarianceScaling.Distribution * @see Glorot et al., 2010 */ -public class Glorot extends VarianceScaling { +public class Glorot extends VarianceScaling { public static final double SCALE = 1.0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index ce99da80bf7..5c0f3bbc8cb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -66,7 +67,7 @@ * href="https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html">He * et al., 2015 */ -public class He extends VarianceScaling { +public class He extends VarianceScaling { public static final double SCALE = 2.0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index 34e6cd790f4..9f462dcfc5c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TType; /** @@ -38,7 +38,7 @@ * * @param The TType for the call operation */ -public class Identity extends BaseInitializer { +public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; private final double gain; @@ -66,10 +66,7 @@ public Identity(Ops tf, double gain) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("DataType must be a float type: " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -79,9 +76,9 @@ public Operand call(Operand dims, DataType dtype) { Shape diagShape = Shape.of(diagSize); Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), dtype); + Operand zero = tf.dtypes.cast(tf.constant(0), type); Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), dtype)); + tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -91,10 +88,10 @@ public Operand call(Operand dims, DataType dtype) { tf.constant((int) shape.size(1)), zero); } else { - Operand zeroMatrix = tf.zeros(dims, dtype); + Operand zeroMatrix = tf.zeros(dims, type); op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), dtype)); + return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 59dce1fc02e..4beb218783b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -30,8 +29,8 @@ public interface Initializer { * Generates the operation used to perform the initialization. * * @param dims the shape dimensions - * @param dtype the data type + * @param type the type of tensor * @return An operand for the initialization. */ - Operand call(Operand dims, DataType dtype); + Operand call(Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index e2268412fc3..0f29f21284a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -76,7 +77,7 @@ * al., 1998 * @see VarianceScaling.Distribution */ -public class LeCun extends VarianceScaling { +public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b78f34e3d35..ef33d37bcb0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -56,10 +57,10 @@ public Ones(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!(dtype.isNumeric() || dtype.isBoolean())) { - throw new IllegalArgumentException("DataType must be numeric or boolean: " + dtype.name()); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { + throw new IllegalArgumentException("DataType must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index 48e2c56d5be..1aada14d3c4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.framework.utils.ShapeUtils; @@ -22,6 +21,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.linalg.Qr; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -50,7 +50,7 @@ * @param The TType for the call operation * @param The TNumber for the call operation */ -public class Orthogonal extends BaseInitializer { +public class Orthogonal extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; @@ -84,10 +84,7 @@ public Orthogonal(Ops tf, double gain, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("Expected floating point type, got " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -100,22 +97,17 @@ public Operand call(Operand dims, DataType dtype) { long numCols = dimsShape.size(i); Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols)); long[] seeds = {seed, 0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - @SuppressWarnings("unchecked") Operand op = - (Operand) - tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), numdType); - + tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), type); Qr.Options qrOptions = Qr.fullMatrices(false); Qr qrOp = tf.linalg.qr(op, qrOptions); Output qo = qrOp.q(); Output ro = qrOp.r(); Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), dtype)); + tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), dtype)); + return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index f2d8a0d8e6e..5c2311d7439 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; @@ -37,7 +36,7 @@ * @param The TType for the call operation * @param The TNumber for the call operation */ -public class RandomNormal extends BaseInitializer { +public class RandomNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 1.0; @@ -87,16 +86,10 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); + public Operand call(Operand dims, Class type) { long[] seeds = {seed, 0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - @SuppressWarnings("unchecked") - Operand distOp = - (Operand) tf.random.statelessRandomNormal(dims, tf.constant(seeds), numdType); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), dtype)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), dtype)); + Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); + Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); + return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index b665729675d..77c1927bdfa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -14,11 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.random.RandomUniformInt; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TInteger; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -38,7 +38,7 @@ * @param The TType for the call operation * @param The TNumber for the call operation */ -public class RandomUniform extends BaseInitializer { +public class RandomUniform extends BaseInitializer { public static final double MINVAL_DEFAULT = -0.05; public static final double MAXVAL_DEFAULT = 0.05; @@ -77,39 +77,28 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - Operand distOp; - - if (dtype.isInteger()) { + public Operand call(Operand dims, Class type) { + Operand distOp; + if (TInteger.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), numdType), - tf.dtypes.cast(tf.constant(this.maxval), numdType), + tf.dtypes.cast(tf.constant(this.minval), type), + tf.dtypes.cast(tf.constant(this.maxval), type), options); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; - return distOpT; } else { long[] seeds = {seed, 0}; - distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), numdType); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; + distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOpT = tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(this.maxval), dtype)); + distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); } } else { - distOpT = - tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(this.maxval - this.minval), dtype)); - distOpT = tf.math.add(distOpT, tf.dtypes.cast(tf.constant(this.minval), dtype)); + distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); + distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); } - return distOpT; } + return distOp; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index c71cf9a630e..c38c073f282 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; @@ -37,7 +36,7 @@ * @param The TType for the call operation * @param The TNumber for the call operation */ -public class TruncatedNormal extends BaseInitializer { +public class TruncatedNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 0.05; @@ -76,17 +75,11 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); + public Operand call(Operand dims, Class type) { long[] seeds = {seed,0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), numdType); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; + Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(stddev), dtype)), - tf.dtypes.cast(tf.constant(mean), dtype)); + tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), + tf.dtypes.cast(tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index fd33adadd5c..fdf8920ce1b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -14,12 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -58,7 +58,7 @@ * @see VarianceScaling.Mode * @see VarianceScaling.Distribution */ -public class VarianceScaling extends BaseInitializer { +public class VarianceScaling extends BaseInitializer { public static final double SCALE_DEFAULT = 1.0; public static final Mode MODE_DEFAULT = Mode.FAN_IN; @@ -102,10 +102,7 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("Expected floating point type, got " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); @@ -120,32 +117,28 @@ public Operand call(Operand dims, DataType dtype) { lscale /= Math.max(1., (fans[0] + fans[1]) / 2.); break; } - Operand distOp; - Operand mulOp = null; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; + Operand distOp; + Operand mulOp = null; double stddev; long[] seeds = {seed, 0}; switch (distribution) { case TRUNCATED_NORMAL: - distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; case NORMAL: - distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; case UNIFORM: - distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; } - - // Need to cast TNumber to TType - return (Operand) mulOp; + return mulOp; } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 09dd512ffaa..1582750b37b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; @@ -46,7 +45,7 @@ public Zeros(Ops tf) { } @Override - public Operand call(Operand dims, DataType dtype) { + public Operand call(Operand dims, Class dtype) { return tf.zeros(dims, dtype); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index effdf990f71..88ab71eb4be 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -217,8 +217,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().type()), + cast(getTF(), getTF().constant(1), predictions.asOutput().type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 7701ebfb806..d7bb4ab5128 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -256,8 +256,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().type()), + cast(getTF(), getTF().constant(1), predictions.asOutput().type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 5fdfd4c9b96..b5bf2d32355 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -124,15 +124,15 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? + Operand tLabels = predictions.asOutput().type() == labels.asOutput().type() ? (Operand)labels : - cast(tf, labels, predictions.asOutput().dataType()); + cast(tf, labels, predictions.asOutput().type()); tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), - predictions.asOutput().dataType())); + predictions.asOutput().type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 7a633ede2bf..1d5bbbf41e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.losses; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -51,7 +50,7 @@ public class Losses { */ public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -73,7 +72,7 @@ public static Operand meanAbsoluteErro */ public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -94,7 +93,7 @@ public static Operand meanSquaredError */ public static Operand meanAbsolutePercentageError( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -121,7 +120,7 @@ public static Operand meanAbsolutePerc */ public static Operand meanSquaredLogarithmicError( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -152,7 +151,7 @@ public static Operand meanSquaredLogar */ public static Operand binaryCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -193,7 +192,7 @@ private static Operand binaryCrossentropyHelper( } */ - DataType dataType = output.asOutput().dataType(); + Class dataType = output.asOutput().type(); Operand one = cast(tf, tf.constant(1), dataType); Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -231,7 +230,7 @@ public static Operand categoricalCross boolean fromLogits, float labelSmoothing, int axis) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -284,7 +283,7 @@ public static Operand categoricalCross */ public static Operand categoricalHinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -330,7 +329,7 @@ public static Operand categoricalHinge */ public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int axis) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -357,7 +356,7 @@ public static Operand cosineSimilarity */ public static Operand hinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -393,7 +392,7 @@ public static Operand hinge( */ public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -424,7 +423,7 @@ public static Operand huber( */ public static Operand kullbackLeiblerDivergence( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -454,7 +453,7 @@ public static Operand kullbackLeiblerD */ public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -482,7 +481,7 @@ public static Operand logCosh( */ public static Operand poisson( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -509,7 +508,7 @@ public static Operand poisson( */ public static Operand sparseCategoricalCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand one = cast(tf, tf.constant(1), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -546,7 +545,7 @@ public static Operand sparseCategorica predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); } - Operand iLabels = cast(tf, labels, TInt64.DTYPE); + Operand iLabels = cast(tf, labels, TInt64.class); // Try to adjust the shape so that rank of labels = rank of logits - 1. Shape labelsShape = labels.asOutput().shape(); @@ -586,7 +585,7 @@ public static Operand sparseCategorica */ public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); + Class dataType = predictions.asOutput().type(); Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -614,7 +613,7 @@ public static Operand squaredHinge( */ private static Operand smoothBinaryLabels( Ops tf, Operand labels, float labelSmoothing) { - DataType dataType = labels.asOutput().dataType(); + Class dataType = labels.asOutput().type(); Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), dataType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), halfSmoothing); @@ -633,7 +632,7 @@ private static Operand smoothBinaryLabels( */ private static Operand smoothCategoricalLabels( Ops tf, Operand labels, float labelSmoothing) { - DataType dataType = labels.asOutput().dataType(); + Class dataType = labels.asOutput().type(); Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); Shape labelsShape = labels.asOutput().shape(); int numDims = labelsShape.numDimensions(); @@ -656,7 +655,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().dataType()))); + tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().type()))); return tf.math.mul(x, invNorm); } @@ -669,7 +668,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @return the labels, possibly converted into -1/1. */ private static Operand maybeConvertLabels(Ops tf, Operand labels) { - DataType dataType = labels.asOutput().dataType(); + Class dataType = labels.asOutput().type(); Operand one = cast(tf, tf.constant(1), dataType); Operand zero = cast(tf, tf.constant(0), dataType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 5586a4da889..9372126d3ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -205,8 +205,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().type()), + cast(getTF(), getTF().constant(1), predictions.asOutput().type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 182ce592e55..9b712dcc712 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -125,15 +125,15 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? + Operand tLabels = predictions.asOutput().type() == labels.asOutput().type() ? (Operand)labels : - cast(tf, labels, predictions.asOutput().dataType()); + cast(tf, labels, predictions.asOutput().type()); tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), - predictions.asOutput().dataType())); + predictions.asOutput().type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 463296a1f50..1a9f1339e24 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.losses.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; @@ -249,7 +248,7 @@ public static LossTuple removeSqueezableDimensions( */ public static Operand computeWeightedLoss( Ops tf, Operand loss, Reduction reduction, Operand sampleWeight) { - DataType dataType = loss.asOutput().dataType(); + Class dataType = loss.asOutput().type(); if (sampleWeight == null) { sampleWeight = cast(tf, tf.constant(1), dataType); } @@ -300,7 +299,7 @@ public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.asOutput().dataType())); + totalLoss, cast(tf, tf.constant(numElements), losses.asOutput().type())); } /** @@ -361,7 +360,7 @@ public static Operand rangeCheck( tf.withSubScope("rangeCheck") .withControlDependencies(Collections.singletonList(assertThat)); return ltf.identity(values); - } else if (!cond.asOutput().data().getBoolean()) + } else if (!cond.asOutput().asTensor().getBoolean()) throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); else return values; } @@ -386,7 +385,7 @@ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.asOutput().shape().size()))); - SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.DTYPE); + SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().asOutput().shape().size(); if (diffSize != Shape.UNKNOWN_SIZE) { @@ -409,7 +408,7 @@ public static Operand valueCheck( tf.withSubScope("valueCheck") .withControlDependencies(Collections.singletonList(assertThat)); return ltf.identity(values); - } else if (!cond.asOutput().data().getBoolean()) + } else if (!cond.asOutput().asTensor().getBoolean()) throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); else return values; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 0adf5f58910..822eb490f22 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -141,10 +141,10 @@ protected void createSlots(List> variables) { */ private void createAdaDeltaSlot(Output v) { Operand accumulatorInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); Operand updateInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @@ -157,9 +157,9 @@ protected Op applyDense(Output gradient, Output variable variable, accumSlot, accumUpdateSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(rho), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(rho), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 1a7f4675662..08f5f18a9cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -131,7 +131,7 @@ protected void createSlots(List> variables) { */ private void createAdaGradSlot(Output v) { Operand initializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @@ -140,7 +140,7 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index f76217fda85..df624e41c4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -187,7 +187,7 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.DTYPE); + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.class); Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); graph.addInitializer(globalStepInitializer); } @@ -199,10 +199,10 @@ protected void createSlots(List> variables) { * @param the datatype of the variable. */ private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); Operand sqInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @@ -216,9 +216,9 @@ protected Op applyDense(Output gradient, Output variable gradSlot, gradSquaredSlot, gradient, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()), - tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(l1Strength), gradient.type()), + tf.dtypes.cast(tf.constant(l2Strength), gradient.type()), globalStep); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 8f620678781..72598d12543 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -189,10 +189,10 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); graph.addInitializer(betaTwoPowerInit); } @@ -215,10 +215,10 @@ protected Optional prepare(String scopeName) { */ private void createAdamSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @@ -231,12 +231,12 @@ protected Op applyDense(Output gradient, Output variable variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower, gradient.dataType()), - tf.dtypes.cast(betaTwoPower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), - tf.dtypes.cast(betaOneConst, gradient.dataType()), - tf.dtypes.cast(betaTwoConst, gradient.dataType()), - tf.dtypes.cast(epsilonConst, gradient.dataType()), + tf.dtypes.cast(betaOnePower, gradient.type()), + tf.dtypes.cast(betaTwoPower, gradient.type()), + tf.dtypes.cast(learningRateConst, gradient.type()), + tf.dtypes.cast(betaOneConst, gradient.type()), + tf.dtypes.cast(betaTwoConst, gradient.type()), + tf.dtypes.cast(epsilonConst, gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index 335d83cedfa..cd95bb3bd07 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -137,7 +137,7 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamaxSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); } @@ -150,10 +150,10 @@ protected void createSlots(List> variables) { */ private void createAdamaxSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @@ -167,11 +167,11 @@ protected Op applyDense(Output gradient, Output variable variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), - tf.dtypes.cast(betaOneConst, gradient.dataType()), - tf.dtypes.cast(betaTwoConst, gradient.dataType()), - tf.dtypes.cast(epsilonConst, gradient.dataType()), + tf.dtypes.cast(betaOnePower, gradient.type()), + tf.dtypes.cast(learningRateConst, gradient.type()), + tf.dtypes.cast(betaOneConst, gradient.type()), + tf.dtypes.cast(betaTwoConst, gradient.type()), + tf.dtypes.cast(epsilonConst, gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 04c34a2535e..66314d2ffe0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -230,10 +230,10 @@ protected void createSlots(List> variables) { */ private void createFtrlSlot(Output v) { Operand initializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); Operand linearInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), LINEAR_ACCUMULATOR, linearInitializer); } @@ -248,12 +248,12 @@ protected Op applyDense(Output gradient, Output variable accumSlot, // accum linearSlot, // linear gradient, // gradient - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr - tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1 - tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2 + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), // lr + tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.type()), // l1 + tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.type()), // l2 tf.dtypes.cast( - tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), // l2Shrinkage - tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), // lrPower + tf.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage + tf.dtypes.cast(tf.constant(learningRatePower), gradient.type()), // lrPower options); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index e307855e636..a373b2e5b55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -66,7 +66,7 @@ public GradientDescent(Graph graph, String name, float learningRate) { @Override protected Op applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent( - variable, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 111727d26fa..f6640409d60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -125,7 +125,7 @@ protected void createSlots(List> variables) { * @param the data type of the variable */ private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, initializer); } @@ -136,9 +136,9 @@ protected Op applyDense(Output gradient, Output variable return tf.train.applyMomentum( variable, slot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), ApplyMomentum.useNesterov(useNesterov)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index 48e5135c952..f9900a8ee78 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -1,6 +1,5 @@ package org.tensorflow.framework.optimizers; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -142,15 +141,15 @@ protected void createSlots(List> variables) { for (Output v : variables) { createNadamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); - momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); + momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.class); Assign momentumInit = tf.assign(momentum, tf.constant(1.0F)); ((Graph) tf.scope().env()).addInitializer(momentumInit); } @@ -163,14 +162,14 @@ protected void createSlots(List> variables) { */ private void createNadamSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); Operand momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); } @@ -198,7 +197,7 @@ protected Optional prepare(String scopeName) { point5, tf.math.pow( decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); + tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.class)))))); mT1 = tf.math.mul( @@ -209,7 +208,7 @@ protected Optional prepare(String scopeName) { point5, tf.math.pow( decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); + tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.class)))))); Operand mScheduleNew = tf.math.mul(momentum, mT); @@ -222,57 +221,57 @@ protected Optional prepare(String scopeName) { oneMinusMScheduleNew = tf.math.sub(one, mScheduleNew); oneMinusMScheduleNext = tf.math.sub(one, mScheduleNext); vTPrimeDenominator = - tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); + tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.class))); return Optional.empty(); } /** {@inheritDoc} */ @Override protected Op applyDense(Output gradient, Output variable) { - DataType dType = gradient.dataType(); + Class type = gradient.type(); Variable m = getSlot(variable, FIRST_MOMENT).get(); // first Moment Variable v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment // gPrime = grad / coefficients['oneMinusMScheduleNew'] - Operand gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, dType)); + Operand gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type)); // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) Operand mT = tf.math.add( - tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), - tf.math.mul(tf.dtypes.cast(oneMinusBeta1, dType), gradient)); + tf.math.mul(tf.dtypes.cast(betaOneConst, type), m), + tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient)); // mT = state_ops.assign(m, mT, use_locking=self._use_locking) // update m mT = tf.assign(m, mT, Assign.useLocking(true)); // mTPrime = mT / coefficients['oneMinusMScheduleNext'] - Operand mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, dType)); + Operand mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type)); // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * // math_ops.square(grad)) Operand vT = tf.math.add( - tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), - tf.math.mul(tf.dtypes.cast(oneMinusBeta2, dType), tf.math.square(gradient))); + tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v), + tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient))); // vT = state_ops.assign(v, vT, use_locking=self._use_locking) // update v vT = tf.assign(v, vT, Assign.useLocking(true)); // vTPrime = vT / coefficients['vTPrimeDenominator'] - Operand vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, dType)); + Operand vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type)); // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime) Operand m_t_bar = tf.math.add( - tf.math.mul(tf.dtypes.cast(oneMinusMT, dType), gPrime), - tf.math.mul(tf.dtypes.cast(mT1, dType), mTPrime)); + tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime), + tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime)); // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) + // coefficients['epsilon']) Operand varT = tf.math.sub( variable, tf.math.div( - tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar), - tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType)))); + tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar), + tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type)))); return tf.assign(variable, varT, Assign.useLocking(true)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 70f065814f7..fdf56da4a67 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -220,7 +220,7 @@ private Optional> getSlot(String varName, String s protected void createSlot( Output variable, String slotName, Operand initializer) { Variable slot = - tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); + tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.type()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 9a48a9b8a7a..b3729dc367f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -175,14 +175,14 @@ protected void createSlots(List> variables) { */ private void createRMSPropSlot(Output v) { Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } @@ -199,20 +199,20 @@ protected Op applyDense(Output gradient, Output variable mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(decay), gradient.dataType()), - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(decay), gradient.type()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } return tf.train.applyRmsProp( variable, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(decay), gradient.dataType()), - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(decay), gradient.type()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index aec75e6078a..4e310ff56e3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.utils; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; @@ -35,8 +34,8 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, DataType requiredType) { - return (value.asOutput().dataType() == requiredType) + Ops tf, Operand value, Class requiredType) { + return (value.asOutput().type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 122de9f21ae..743e8114f6e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -17,10 +17,14 @@ import org.tensorflow.*; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TInteger; import org.tensorflow.types.family.TNumber; import java.util.ArrayList; @@ -36,7 +40,7 @@ public class ShapeUtils { * @param dims the Operand containing the shape values * @return a new Shape based on an Operand that contains dimensions */ - public static Shape toShape(Scope scope, Operand dims) { + public static Shape toShape(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); return Shape.of(longDims); } @@ -62,65 +66,37 @@ public static int[] getIntArray(Scope scope, Operand dims) { * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ - public static long[] getLongArray(Scope scope, Operand dims) { - DataType dType = dims.asOutput().dataType(); - if (!dType.isInteger()) { - throw new IllegalArgumentException("the data type must be an integer type"); - } - List result = new ArrayList<>(); - + public static long[] getLongArray(Scope scope, Operand dims) { if (scope.env().isEager()) { - if (dType.equals(TInt32.DTYPE)) { - @SuppressWarnings("unchecked") - Operand idims = (Operand) dims; - - idims.asOutput().data().scalars().forEach(s -> result.add((long) s.getInt())); - } else if (dType.equals(TInt64.DTYPE)) { - @SuppressWarnings("unchecked") - Operand ldims = (Operand) dims; - ldims.asOutput().data().scalars().forEach(s -> result.add(s.getLong())); - } else if (dType.equals(TUint8.DTYPE)) { - @SuppressWarnings("unchecked") - Operand udims = (Operand) dims; - udims.asOutput().data().scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } - - } else { - try (Session session = new Session((Graph) scope.env())) { - if (dType.equals(TInt32.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TInt32.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add((long) s.getInt())); - } - } else if (dType.equals(TInt64.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TInt64.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add(s.getLong())); - } - } else if (dType.equals(TUint8.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TUint8.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add(s.getObject().longValue())); - } - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } - } + return getLongArray(dims.asTensor()); + } + try (Session session = new Session((Graph)scope.env()); + T tensor = session.runner().fetch(dims).run().get(0)) { + return getLongArray(tensor); } - return result.stream().mapToLong(i -> i).toArray(); } /** - * Gets the shape for the data within a Tensor + * Converts a TInt32 or TInt64 to a java long array * - * @param tensor the tensor - * @return the Shape of the tensor's data; + * @param scope the TensorFlow scope + * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type + * @return the long array + * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ - public static Shape getShape(Tensor tensor) { - NdArray data = (NdArray) tensor.data(); - return data.shape(); + public static long[] getLongArray(T dims) { + List result = new ArrayList<>(); + if (dims instanceof TInt32) { + ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + } else if (dims instanceof TInt64) { + ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + } else if (dims instanceof TUint8) { + ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } + return result.stream().mapToLong(i -> i).toArray(); } /** diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index f54401515ab..a0aa2c4b453 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -90,8 +90,8 @@ public void testCallFloat16() { Ops tf = session.getTF(); ReLU instance = new ReLU<>(tf); Operand result = - instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.DTYPE)); - session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.DTYPE), result); + instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); + session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java index 6a54cb08de6..41be23a8f6f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -40,19 +39,15 @@ public void testEagerBatchDataset() { Arrays.asList( tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .batch(2); int count = 0; for (List> components : dataset) { - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { - - assertEquals(testMatrix1.slice(range(count, count + 2)), batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 2)), batch2.data()); - + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + assertEquals(testMatrix1.slice(range(count, count + 2)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 2)), batch2); count += 2; } } @@ -63,23 +58,16 @@ public void testDropLastBatch() { Ops tf = Ops.create(); Dataset dataset = Dataset .fromTensorSlices(tf, - Arrays.asList( - tf.constant(testMatrix1), - tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .batch(3, true); int count = 0; for (List> components : dataset) { - - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { - - assertEquals(testMatrix1.slice(range(count, count + 3)), batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 3)), batch2.data()); - + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); count += 3; } } @@ -90,31 +78,23 @@ public void testKeepLastBatch() { Ops tf = Ops.create(); Dataset dataset = Dataset .fromTensorSlices(tf, - Arrays.asList( - tf.constant(testMatrix1), - tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .batch(3, false); int count = 0; boolean foundLastBatch = false; for (List> components : dataset) { - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor();) { if (count == 0) { - assertEquals(testMatrix1.slice(range(count, count + 3)), - batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 3)), - batch2.data()); + assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); count += 3; } else { - assertEquals(testMatrix1.slice(range(count, count + 1)), - batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 1)), - batch2.data()); + assertEquals(testMatrix1.slice(range(count, count + 1)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 1)), batch2); foundLastBatch = true; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 6bb6e21f330..c5cb3978b4b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -16,17 +16,17 @@ package org.tensorflow.framework.data; import org.junit.jupiter.api.Test; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; import java.util.Arrays; import java.util.List; +import org.tensorflow.util.TensorList; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -39,7 +39,7 @@ public void testGraphIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); DatasetIterator iterator = dataset.makeOneShotIterator(); @@ -53,15 +53,12 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List> outputs = session.runner().fetch(x).fetch(y).run(); - - try (Tensor xBatch = outputs.get(0).expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).expect(TInt32.DTYPE)) { - assertEquals(testMatrix1.get(batches), xBatch.data()); - assertEquals(testMatrix2.get(batches), yBatch.data()); - batches++; - } + try (TensorList outputs = session.runner().fetch(x).fetch(y).run()) { + TInt32 xBatch = outputs.get(0); + TInt32 yBatch = outputs.get(1); + assertEquals(testMatrix1.get(batches), xBatch); + assertEquals(testMatrix2.get(batches), yBatch); + batches++; } catch (TFOutOfRangeException e) { break; } @@ -77,17 +74,15 @@ public void testEagerIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); int count = 0; for (List> outputs : dataset) { - try (Tensor batch1 = outputs.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = outputs.get(1).asTensor().expect(TInt32.DTYPE); ) { - - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); - + try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); + TInt32 batch2 = (TInt32)outputs.get(1).asTensor()) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5960442ff70..c6be6960089 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -17,11 +17,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.IntNdArray; @@ -30,6 +29,7 @@ import java.util.Arrays; import java.util.List; +import org.tensorflow.util.TensorList; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -60,13 +60,13 @@ public void testGraphIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes) .mapAllComponents( component -> - tf.math.mul(component.asOutput().expect(TInt32.DTYPE), tf.constant(2))); + tf.math.mul(component.asOutput().expect(TInt32.class), tf.constant(2))); DatasetIterator iterator = dataset.makeOneShotIterator(); List> components = iterator.getNext(); @@ -78,17 +78,12 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List> outputs = session.runner().fetch(X).fetch(y).run(); - - try (Tensor XBatch = outputs.get(0).expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).expect(TInt32.DTYPE)) { - - assertEquals(mapped1.get(batches), XBatch.data()); - assertEquals(mapped2.get(batches), yBatch.data()); - + try (TensorList outputs = session.runner().fetch(X).fetch(y).run()) { + TInt32 xBatch = outputs.get(0); + TInt32 yBatch = outputs.get(1); + assertEquals(mapped1.get(batches), xBatch); + assertEquals(mapped2.get(batches), yBatch); batches++; - } } catch (TFOutOfRangeException e) { break; } @@ -105,21 +100,18 @@ public void testEagerIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes) - .mapAllComponents( - op -> tf.math.mul(op.asOutput().expect(TInt32.DTYPE), tf.constant(2))); + .mapAllComponents(op -> tf.math.mul(op.expect(TInt32.class), tf.constant(2))); int count = 0; for (List> outputs : dataset) { - try (Tensor XBatch = outputs.get(0).asTensor().expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).asTensor().expect(TInt32.DTYPE); ) { - - assertEquals(mapped1.get(count), XBatch.data()); - assertEquals(mapped2.get(count), yBatch.data()); - + try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); + TInt32 yBatch = (TInt32)outputs.get(1).asTensor()) { + assertEquals(mapped1.get(count), XBatch); + assertEquals(mapped2.get(count), yBatch); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java index 9ff8080034d..d0cdb4527a5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -35,16 +34,15 @@ public void testEagerSkipDataset() { Dataset.fromTensorSlices( tf, Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .skip(2); int count = 2; for (List> components : dataset) { - try (Tensor batch1 = components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE); ) { - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java index 4419f4660db..79a2e79c72e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -36,16 +35,15 @@ public void testEagerTakeDataset() { Dataset.fromTensorSlices( tf, Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .take(4); int count = 0; for (List> components : dataset) { - try (Tensor batch1 = components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = components.get(1).asTensor().expect(TInt32.DTYPE); ) { - - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 46e4232d5ae..4e81e0620e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -52,7 +52,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -68,7 +68,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -84,7 +84,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xffL); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -98,7 +98,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 12.F); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -113,7 +113,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 11.); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -130,7 +130,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.DTYPE); + instance.call(tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -146,7 +146,7 @@ public void testCallBool() { Boolean[] expected = {true, true, true, true}; Constant instance = new Constant<>(tf, true); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -159,8 +159,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 11.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index a68bf2a0a98..e9769806928 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -51,9 +51,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,8 +68,8 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,8 +82,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -97,8 +97,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -109,9 +109,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -122,9 +122,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -135,10 +135,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = + Glorot instance = new Glorot<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 468759d347f..8953fa3005e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -51,8 +51,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +66,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +80,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +95,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +107,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +120,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +133,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index adb6c0c118a..6eee5473937 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -45,23 +45,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of call method, of class Orthogonal. */ - @Test - public void testCallInt() { - for (TestSession.Mode tfMode : tfModes) - assertThrows( - java.lang.IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - instance.call(tf.constant(shape), TInt32.DTYPE); - fail("Should have thrown IllegalArgumentException on Integer type"); - } - }); - } - /** Test of call method, of class Constant. */ @Test public void testCallFloat() { @@ -82,7 +65,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -108,7 +91,7 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -121,8 +104,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Identity instance = new Identity<>(tf, 2.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 6033f9e12a5..336850a5549 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -51,8 +51,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +66,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +80,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +95,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +107,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +120,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +133,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index bbd2ba3d384..053ba5dd7ff 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -52,7 +52,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -66,7 +66,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -80,7 +80,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -94,7 +94,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -109,7 +109,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -126,7 +126,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - instance.call(tf.constant(shape), TString.DTYPE); + instance.call(tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -141,7 +141,7 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -154,8 +154,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index a4fff5fd19c..22b89d9177c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -47,23 +47,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of call method, of class Orthogonal. */ - @Test - public void testCallInt() { - for (TestSession.Mode tfMode : tfModes) - assertThrows( - java.lang.IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - instance.call(tf.constant(shape), TInt32.DTYPE); - fail("Should have thrown IllegalArgumentException on Integer type"); - } - }); - } - /** Test of call method, of class Orthogonal. */ @Test public void testCallFloat() { @@ -173,8 +156,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -288,8 +271,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -301,9 +284,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 50aec670503..3b2b3bdb243 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -52,9 +52,9 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +82,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index d3f9af74209..23e26083a9b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -53,9 +53,9 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -84,9 +84,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -98,10 +98,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 0a551df2f38..96bf915e199 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -52,9 +52,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +82,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 77e0dd7afc7..159affb07e2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -50,14 +50,14 @@ public void testCallFloat1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -71,14 +71,14 @@ public void testCallDouble1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -91,14 +91,14 @@ public void testCallFloat1FanInNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -112,14 +112,14 @@ public void testCalltestSoftmaxDouble1FanInNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -132,10 +132,10 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -149,10 +149,10 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -164,11 +164,11 @@ public void testReproducible1() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -180,15 +180,15 @@ public void testReproducible2() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -200,15 +200,15 @@ public void testReproducible3() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_OUT, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -220,11 +220,11 @@ public void testReproducible4() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java index 975678add19..21bad6ff360 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java @@ -49,7 +49,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -63,7 +63,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -77,7 +77,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -91,7 +91,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -106,7 +106,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -120,7 +120,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TString.DTYPE); + Operand operand = instance.call(tf.constant(shape), TString.class); session.evaluateString(operand, String::isEmpty); } } @@ -135,7 +135,7 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -148,8 +148,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 5c4ce542c65..bf5e2b1450b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -89,8 +89,8 @@ public void testBasic() { float[] var1Init = {3.0F, 4.0F}; float[] fgrads = {grad, grad}; Shape shape = Shape.of(var0Init.length); - Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index ef9053ff1eb..d5b2657a4fc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -64,8 +64,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index c5ae178b84c..d2eb63056e8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -79,8 +79,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 461fa75397f..4f2afa07972 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -80,8 +79,8 @@ public void testBasic() { Ops tf = instance.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -140,25 +139,23 @@ public void testBasic() { (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) }; - try (Tensor result = + try (TFloat32 result = session .getGraphSession() .runner() .fetch("beta1_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); } - try (Tensor result = + try (TFloat32 result = session .getGraphSession() .runner() .fetch("beta2_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index de17395f76a..ea0e96c7fad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -101,8 +100,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -149,15 +148,14 @@ public void testBasic() { // Test powers final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); - try (Tensor result = + try (TFloat32 result = session .getGraphSession() .runner() .fetch("beta1_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 597f8e52bcd..7698d76f957 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -76,8 +76,8 @@ public void testFtrlWithL1L2L2Shrinkage() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -141,8 +141,8 @@ public void testFtrlWithL1() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -207,8 +207,8 @@ public void testFtrlWithL1L2() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -273,8 +273,8 @@ public void doTestFtrlwithoutRegularization() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 4362c54d815..aefcc537979 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -61,8 +61,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index bcfff97773d..ca72c4f415c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -77,8 +77,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -130,8 +130,8 @@ public void testMomentum() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index a583d74246b..4df0994dffc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -102,8 +101,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -147,15 +146,14 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - try (Tensor result = + try (TFloat32 result = session .getGraphSession() .runner() .fetch("momentum") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; @@ -167,15 +165,14 @@ public void testBasic() { Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; - try (Tensor result = + try (TFloat32 result = session .getGraphSession() .runner() .fetch("momentum") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 202fb21ef68..3b002cd1dbe 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -87,8 +87,8 @@ public void testDense() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index bca90211e50..f9033c49cc8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -84,57 +84,57 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override public void evaluate(double expected, Operand input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class type = input.asOutput().type(); + if (type == TFloat32.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); + } else if (type == TFloat64.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + } else if (type == TInt32.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (dtype == TInt64.DTYPE) { + o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + } else if (type == TInt64.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (dtype == TUint8.DTYPE) { + o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + } else if (type == TUint8.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); + o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } @@ -146,71 +146,71 @@ public void evaluate(Number[] expected, Output input) { expected.length, size, () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class type = input.type(); + if (type == TFloat32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals( expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + } else if (type == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals( expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + } else if (type == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (dtype == TInt64.DTYPE) { + } else if (type == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (type == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } @@ -219,69 +219,69 @@ public void evaluate(Number[] expected, Output input) { /** {@inheritDoc} */ @Override public void evaluate(FloatNdArray expected, Output input) { - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class type = input.type(); + if (type == TFloat32.class) { Output o = (Output) input; AtomicLong index = new AtomicLong(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + } else if (type == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + } else if (type == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - for (IntNdArray f : o.data().scalars()) { + for (IntNdArray f : o.asTensor().scalars()) { assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); } - } else if (dtype == TInt64.DTYPE) { + } else if (type == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (type == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); @@ -296,10 +296,10 @@ public void evaluateString(Output input, Predicate predicate) { if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %s\n", predicate.test(input.data().getObject()), input.data().getObject()); + "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); } else { input - .data() + .asTensor() .scalars() .forEachIndexed( (idx, s) -> @@ -310,9 +310,9 @@ public void evaluateString(Output input, Predicate predicate) { } index.set(0); if (isScalar) { - assertTrue(predicate.test(input.data().getObject())); + assertTrue(predicate.test(input.asTensor().getObject())); } else { - input.data().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); + input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } @@ -320,16 +320,16 @@ public void evaluateString(Output input, Predicate predicate) { @Override public void evaluate(Output input, Predicate predicate) { AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); + Class type = input.asOutput().type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (type == TFloat32.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getFloat()), o.data().getFloat()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -340,20 +340,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getFloat())); + assertTrue(predicate.test(o.asTensor().getFloat())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } - } else if (dtype == TFloat64.DTYPE) { + } else if (type == TFloat64.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getDouble()), o.data().getDouble()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -364,20 +364,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getDouble())); + assertTrue(predicate.test(o.asTensor().getDouble())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getDouble()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); } - } else if (dtype == TFloat16.DTYPE) { + } else if (type == TFloat16.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getFloat()), o.data().getFloat()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -388,20 +388,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getFloat())); + assertTrue(predicate.test(o.asTensor().getFloat())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } - } else if (dtype == TInt32.DTYPE) { + } else if (type == TInt32.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(o.data().getInt()), o.data().getInt()); + "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -412,20 +412,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getInt())); + assertTrue(predicate.test(o.asTensor().getInt())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getInt()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); } - } else if (dtype == TInt64.DTYPE) { + } else if (type == TInt64.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(o.data().getLong()), o.data().getLong()); + "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -436,20 +436,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getLong())); + assertTrue(predicate.test(o.asTensor().getLong())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getLong()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); } - } else if (dtype == TUint8.DTYPE) { + } else if (type == TUint8.class) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %x\n", predicate.test(o.data().getByte()), o.data().getByte()); + "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -460,14 +460,14 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getByte())); + assertTrue(predicate.test(o.asTensor().getByte())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getByte()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + type); } } @@ -482,13 +482,13 @@ public void evaluate(String[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { input - .data() + .asTensor() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } index.set(0); input - .data() + .asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } @@ -504,13 +504,13 @@ public void evaluate(Boolean[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { input - .data() + .asTensor() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } index.set(0); input - .data() + .asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } @@ -522,184 +522,184 @@ public void evaluate(Output expected, Output input) { : String.format( "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); - DataType dtype = input.asOutput().dataType(); + Class type = input.asOutput().type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (type == TFloat32.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.data().getFloat(), o.data().getFloat()); + System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", - index.getAndIncrement(), x.data().getFloat(idx), f.getFloat())); + index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getFloat(), o.data().getFloat(), epsilon); + assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(x.data().getFloat(idx), f.getFloat(), epsilon)); + (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (type == TFloat64.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.data().getDouble(), o.data().getDouble()); + System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", - index.getAndIncrement(), x.data().getDouble(idx), f.getDouble())); + index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getDouble(), o.data().getDouble(), epsilon); + assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(x.data().getDouble(idx), f.getDouble(), epsilon)); + (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (type == TInt32.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.data().getInt(), o.data().getInt()); + System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), x.data().getInt(idx), f.getInt())); + index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getInt(), o.data().getInt()); + assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getInt(idx), f.getInt())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (type == TInt64.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.data().getLong(), o.data().getLong()); + System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), x.data().getLong(idx), f.getLong())); + index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getLong(), o.data().getLong()); + assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getLong(idx), f.getLong())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (type == TUint8.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.data().getByte(), o.data().getByte()); + System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %x <==> %x\n", - index.getAndIncrement(), x.data().getByte(idx), f.getByte())); + index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getByte(), o.data().getByte()); + assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getByte(idx), f.getByte())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); } - } else if (dtype == TString.DTYPE) { + } else if (type == TString.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.data().getObject(), o.data().getObject()); + System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %s <==> %s\n", - index.getAndIncrement(), x.data().getObject(idx), f.getObject())); + index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getObject(), o.data().getObject()); + assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getObject(idx), f.getObject())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); } - } else if (dtype == TBool.DTYPE) { + } else if (type == TBool.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.data().getBoolean(), o.data().getBoolean()); + System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %b <==> %b\n", - index.getAndIncrement(), x.data().getBoolean(idx), f.getBoolean())); + index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getBoolean(), o.data().getBoolean()); + assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getBoolean(idx), f.getBoolean())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); } } } @@ -707,51 +707,51 @@ public void evaluate(Output expected, Output input) { /** {@inheritDoc} */ @Override public void print(PrintWriter writer, Output input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class type = input.asOutput().type(); + if (type == TFloat32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (dtype == TFloat64.DTYPE) { + } else if (type == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (dtype == TInt32.DTYPE) { + } else if (type == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } else if (dtype == TInt64.DTYPE) { + } else if (type == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (type == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (dtype == TString.DTYPE) { + } else if (type == TString.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (dtype == TBool.DTYPE) { + } else if (type == TBool.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } else { - writer.println("Unexpected DataType: " + dtype); + writer.println("Unexpected Class: " + type); } writer.flush(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 33ddec6dce3..d2e856e9f22 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -104,89 +104,84 @@ public void run(Op op) { /** {@inheritDoc} */ @Override public void evaluate(double expected, Operand input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class dtype = input.asOutput().type(); + if (dtype == TFloat32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + dtype); } } @@ -198,223 +193,203 @@ public void evaluate(Number[] expected, Output input) { expected.length, size, () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class dtype = input.asOutput().type(); + if (dtype == TFloat32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals( expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals( expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + dtype); } } /** {@inheritDoc} */ @Override public void evaluate(FloatNdArray expected, Output input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class dtype = input.asOutput().type(); + if (dtype == TFloat32.class) { AtomicLong index = new AtomicLong(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals( expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals( expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + dtype); } } @@ -428,19 +403,17 @@ public void evaluate(String[] expected, Output input) { () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } @@ -456,19 +429,17 @@ public void evaluate(Boolean[] expected, Output input) { () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } @@ -482,331 +453,313 @@ public void evaluate(Output expected, Output input) { "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); - if (!dtype.equals(expected.dataType())) { + Class dtype = input.asOutput().type(); + if (!dtype.equals(expected.type())) { throw new IllegalArgumentException( String.format( "Both data type must be equal, inout = %s, expected = %s", - dtype, expected.dataType())); + dtype, expected.type())); } boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (dtype == TFloat32.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat32 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", index.getAndIncrement(), - finalExpected.data().getFloat(idx), + finalExpected.asTensor().getFloat(idx), f.getFloat())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat32 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getFloat(idx), f.getFloat(), epsilon)); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat64 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble()); + "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", index.getAndIncrement(), - finalExpected.data().getDouble(idx), + finalExpected.asTensor().getDouble(idx), f.getDouble())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat64 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon); + assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getDouble(idx), f.getDouble(), epsilon)); + assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } - } else if (dtype == TFloat16.DTYPE) { + } else if (dtype == TFloat16.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE)) { + try (TFloat16 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat16 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result - .data() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.data().getFloat(idx), - f.getFloat())); + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor().getFloat(idx), + f.getFloat())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE)) { + try (TFloat16 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TFloat16 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result - .data() - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.data().getFloat(idx), f.getFloat(), epsilon)); + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TInt32 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getInt(), result.data().getInt()); + "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.data().getInt(idx), f.getInt())); + index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TInt32 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon); + assertEquals(expectedResult.getInt(), result.getInt(), epsilon); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getInt(idx), f.getInt(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); } } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TInt64 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong()); + "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", index.getAndIncrement(), - finalExpected.data().getLong(idx), + finalExpected.asTensor().getLong(idx), f.getLong())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TInt64 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon); + assertEquals(expectedResult.getLong(), result.getLong(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getLong(idx), f.getLong(), epsilon)); + assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TUint8 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getByte(), result.data().getByte()); + "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", index.getAndIncrement(), - finalExpected.data().getByte(idx), + finalExpected.asTensor().getByte(idx), f.getByte())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0); + TUint8 expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getByte(), result.data().getByte(), epsilon); + assertEquals(expectedResult.getByte(), result.getByte(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getByte(idx), f.getByte(), epsilon)); + assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } - } else if (dtype == TBool.DTYPE) { + } else if (dtype == TBool.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + this.getGraphSession().runner().fetch(input).run().get(0); + TBool expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean()); + "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %b <==> %b\n", index.getAndIncrement(), - finalExpected.data().getBoolean(idx), + finalExpected.asTensor().getBoolean(idx), f.getBoolean())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = this.getGraphSession().runner().fetch(input).run().get(0); + TBool expectedResult = this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean()); + assertEquals(expectedResult.getBoolean(), result.getBoolean()); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getBoolean(idx), f.getBoolean())); + (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); } } - } else if (dtype == TString.DTYPE) { + } else if (dtype == TString.class) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0); + TString expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject()); + "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %s <==> %s\n", index.getAndIncrement(), - finalExpected.data().getObject(idx), + finalExpected.asTensor().getObject(idx), f.getObject())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0); + TString expectedResult = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getObject(), result.data().getObject()); + assertEquals(expectedResult.getObject(), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getObject(idx), f.getObject())); + (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); } } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + dtype); } } @@ -816,15 +769,14 @@ public void evaluateString(Output input, Predicate predicate) { boolean isScalar = input.shape().equals(Shape.scalar()); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %s\n", - predicate.test(result.data().getObject()), result.data().getObject()); + predicate.test(result.getObject()), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -835,13 +787,12 @@ public void evaluateString(Output input, Predicate predicate) { } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getObject())); + assertTrue(predicate.test(result.getObject())); } else { result - .data() .scalars() .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } @@ -852,19 +803,18 @@ public void evaluateString(Output input, Predicate predicate) { @Override public void evaluate(Output input, Predicate predicate) { AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); + Class dtype = input.asOutput().type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (dtype == TFloat32.class) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", - predicate.test(result.data().getFloat()), result.data().getFloat()); + predicate.test(result.getFloat()), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -875,28 +825,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getFloat())); + assertTrue(predicate.test(result.getFloat())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", - predicate.test(result.data().getDouble()), result.data().getDouble()); + predicate.test(result.getDouble()), result.getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -907,27 +855,25 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getDouble())); + assertTrue(predicate.test(result.getDouble())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getDouble()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); } } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt()); + "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -938,28 +884,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getInt())); + assertTrue(predicate.test(result.getInt())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getInt()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", - predicate.test(result.data().getLong()), result.data().getLong()); + predicate.test(result.getLong()), result.getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -970,28 +914,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getLong())); + assertTrue(predicate.test(result.getLong())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getLong()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", - predicate.test(result.data().getByte()), result.data().getByte()); + predicate.test(result.getByte()), result.getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -1002,19 +944,18 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getByte())); + assertTrue(predicate.test(result.getByte())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getByte()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); } } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + dtype); } } @@ -1023,119 +964,112 @@ public void evaluate(Output input, Predicate predic public void print(PrintWriter writer, Output input) { boolean isScalar = input.asOutput().shape().size() == 1; - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class dtype = input.type(); + if (dtype == TFloat32.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat()); + writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (dtype == TFloat64.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %f\n", index.getAndIncrement(), ((Output) input).data().getDouble()); + "%d). %f\n", index.getAndIncrement(), ((Output) input).asTensor().getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } - } else if (dtype == TInt32.DTYPE) { + } else if (dtype == TInt32.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).data().getInt()); + "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } - } else if (dtype == TInt64.DTYPE) { + } else if (dtype == TInt64.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).data().getLong()); + "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } - } else if (dtype == TUint8.DTYPE) { + } else if (dtype == TUint8.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %x\n", index.getAndIncrement(), ((Output) input).data().getByte()); + "%d). %x\n", index.getAndIncrement(), ((Output) input).asTensor().getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } } - } else if (dtype == TBool.DTYPE) { + } else if (dtype == TBool.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %b\n", index.getAndIncrement(), ((Output) input).data().getBoolean()); + "%d). %b\n", index.getAndIncrement(), ((Output) input).asTensor().getBoolean()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } } - } else if (dtype == TString.DTYPE) { + } else if (dtype == TString.class) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %s\n", index.getAndIncrement(), ((Output) input).data().getObject()); + "%d). %s\n", index.getAndIncrement(), ((Output) input).asTensor().getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } } } else { - writer.println("Unexpected DataType: " + dtype); + writer.println("Unexpected Class: " + dtype); } writer.flush(); }