Skip to content

TensorFlow type system refactoring #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
  •  
  •  
  •  
24 changes: 1 addition & 23 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,29 +55,7 @@
*
* @param <T> the type of values to be mapped
*/
public interface NdArray<T> {

/**
* @return the shape of this N-dimensional array
*/
Shape shape();

/**
* @return the rank of this N-dimensional array
*/
default int rank() {
return shape().numDimensions();
}

/**
* Computes and returns the total size of this N-dimensional array, in number of values.
*
* <p>For example, given a 3x3x2 matrix, the return value will be 18.
* @return total size of this nd array
*/
default long size() {
return shape().size();
}
public interface NdArray<T> extends NdArrayBase {

/**
* Returns a sequence of all elements at a given dimension.
Expand Down
47 changes: 47 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArrayBase.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow.ndarray;

import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.index.Index;

public interface NdArrayBase {

/**
* @return the shape of this N-dimensional array
*/
Shape shape();

/**
* @return the rank of this N-dimensional array
*/
default int rank() {
return shape().numDimensions();
}

/**
* Computes and returns the total size of this N-dimensional array, in number of values.
*
* <p>For example, given a 3x3x2 matrix, the return value will be 18.
* @return total size of this nd array
*/
default long size() {
return shape().size();
}
}
8 changes: 4 additions & 4 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static ByteNdArray vectorOf(byte... values) {
if (values == null) {
throw new IllegalArgumentException("Values cannot be null");
}
return wrap(DataBuffers.of(values, false, false), Shape.of(values.length));
return wrap(Shape.of(values.length), DataBuffers.of(values, false, false));
}

/**
Expand All @@ -81,19 +81,19 @@ public static ByteNdArray ofBytes(Shape shape) {
if (shape == null) {
throw new IllegalArgumentException("Shape cannot be null");
}
return wrap(DataBuffers.ofBytes(shape.size()), shape);
return wrap(shape, DataBuffers.ofBytes(shape.size()));
}

/**
* Wraps a buffer in a byte N-dimensional array of a given shape.
*
* @param buffer buffer to wrap
* @param shape shape of the array
* @param buffer buffer to wrap
* @return new byte N-dimensional array
* @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger
* in the buffer size
*/
public static ByteNdArray wrap(ByteDataBuffer buffer, Shape shape) {
public static ByteNdArray wrap(Shape shape, ByteDataBuffer buffer) {
return ByteDenseNdArray.create(buffer, shape);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
static Type DataTypeOf(const Type& type) {
return Class("DataType", "org.tensorflow").add_parameter(type);
}
static Type ForDataType(DataType data_type) {
switch (data_type) {
case DataType::DT_BOOL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
if (attr.type().kind() == Type::GENERIC &&
default_types.find(attr.type().name()) != default_types.end()) {
factory_statement << default_types.at(attr.type().name()).name()
<< ".DTYPE";
<< ".class";
} else {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
factory_statement << attr.var().name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,19 @@ class TypeResolver {
std::pair<Type, Type> MakeTypePair(const Type& type) {
return std::make_pair(type, type);
}
Type NextGeneric() {
Type NextGeneric(const OpDef_AttrDef& attr_def) {
char generic_letter = next_generic_letter_++;
if (next_generic_letter_ > 'Z') {
next_generic_letter_ = 'A';
}
return Type::Generic(string(1, generic_letter))
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
return Type::Generic(string(1, generic_letter));
}
Type TypeFamilyOf(const OpDef_AttrDef& attr_def) {
// TODO(karllessard) support more type families
if (IsRealNumbers(attr_def.allowed_values())) {
return Type::Interface("TNumber", "org.tensorflow.types.family");
}
return Type::Interface("TType", "org.tensorflow.types.family");
}
};

Expand Down Expand Up @@ -152,15 +158,12 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray"));

} else if (attr_type == "tensor") {
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
.add_parameter(Type::Wildcard()));
types = MakeTypePair(Type::Class("TType", "org.tensorflow.types.family"));

} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
if (IsRealNumbers(attr_def.allowed_values())) {
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
}
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def);
type.add_supertype(TypeFamilyOf(attr_def));
types = MakeTypePair(type, Type::Class("Class"));

} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
Expand Down Expand Up @@ -306,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
? Type::DataTypeOf(types.first)
? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
SourceWriter& SourceWriter::AppendType(const Type& type) {
if (type.wildcard()) {
Append("?");
WriteTypeBounds(type.supertypes());
} else {
Append(type.name());
if (!type.parameters().empty()) {
Expand Down Expand Up @@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics(
Append(", ");
}
Append(pt->name());
if (!pt->supertypes().empty()) {
Append(" extends ").AppendType(pt->supertypes().front());
}
WriteTypeBounds(pt->supertypes());
first = false;
}
return Append(">");
}

SourceWriter& SourceWriter::WriteTypeBounds(
const std::list<Type>& bounds) {
bool first = true;
for (const Type& bound : bounds) {
if (first) {
Append(" extends ");
first = false;
} else {
Append(" & ");
}
AppendType(bound);
}
return *this;
}

SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
int modifiers) {
GenericNamespace* generic_namespace;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class SourceWriter {
SourceWriter& WriteJavadoc(const Javadoc& javadoc);
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
SourceWriter& WriteTypeBounds(const std::list<Type>& bounds);
GenericNamespace* PushGenericNamespace(int modifiers);
void PopGenericNamespace();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.experimental.DataServiceDataset;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data.experimental} operations as {@link Op Op}s
Expand Down Expand Up @@ -54,7 +54,7 @@ public final class DataExperimentalOps {
public DataServiceDataset dataServiceDataset(Operand<TInt64> datasetId,
Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol,
Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, Operand<?> iterationCounter,
List<DataType<?>> outputTypes, List<Shape> outputShapes,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes,
DataServiceDataset.Options... options) {
return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.AnonymousIterator;
Expand Down Expand Up @@ -49,6 +48,7 @@
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data} operations as {@link Op Op}s
Expand All @@ -72,7 +72,7 @@ public final class DataOps {
* @param outputShapes
* @return a new instance of AnonymousIterator
*/
public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes) {
return AnonymousIterator.create(scope, outputTypes, outputShapes);
}
Expand All @@ -90,8 +90,8 @@ public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
* @return a new instance of BatchDataset
*/
public BatchDataset batchDataset(Operand<?> inputDataset, Operand<TInt64> batchSize,
Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes,
BatchDataset.Options... options) {
Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes, BatchDataset.Options... options) {
return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options);
}

Expand Down Expand Up @@ -126,7 +126,7 @@ public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compre
* @return a new instance of ConcatenateDataset
*/
public ConcatenateDataset concatenateDataset(Operand<?> inputDataset, Operand<?> anotherDataset,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -161,8 +161,8 @@ public DeserializeIterator deserializeIterator(Operand<?> resourceHandle, Operan
* @param outputShapes
* @return a new instance of Iterator
*/
public Iterator iterator(String sharedName, String container, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public Iterator iterator(String sharedName, String container,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return Iterator.create(scope, sharedName, container, outputTypes, outputShapes);
}

Expand All @@ -174,8 +174,8 @@ public Iterator iterator(String sharedName, String container, List<DataType<?>>
* @param outputShapes
* @return a new instance of IteratorGetNext
*/
public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNext iteratorGetNext(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes);
}

Expand All @@ -188,7 +188,7 @@ public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> ou
* @return a new instance of IteratorGetNextAsOptional
*/
public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes);
}

Expand All @@ -205,8 +205,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
* @param outputShapes
* @return a new instance of IteratorGetNextSync
*/
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -252,8 +252,8 @@ public OptionalFromValue optionalFromValue(Iterable<Operand<?>> components) {
* @param outputShapes
* @return a new instance of OptionalGetValue
*/
public OptionalGetValue optionalGetValue(Operand<?> optional, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public OptionalGetValue optionalGetValue(Operand<?> optional,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return OptionalGetValue.create(scope, optional, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -287,7 +287,7 @@ public OptionalNone optionalNone() {
* @return a new instance of RangeDataset
*/
public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
Operand<TInt64> step, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
Operand<TInt64> step, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes);
}

Expand All @@ -302,7 +302,7 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
* @return a new instance of RepeatDataset
*/
public RepeatDataset repeatDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand All @@ -329,7 +329,7 @@ public SerializeIterator serializeIterator(Operand<?> resourceHandle,
* @return a new instance of SkipDataset
*/
public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand All @@ -345,7 +345,7 @@ public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
* @return a new instance of TakeDataset
*/
public TakeDataset takeDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -406,8 +406,8 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames,
* @param outputShapes
* @return a new instance of ZipDataset
*/
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes);
}
}
Loading