Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9e92687
Initial commit of gradient descent optimizers.
Craigacp Nov 5, 2019
68a353c
Adding Apache 2.0 license header to all optimizer files.
Craigacp Nov 5, 2019
b3f4be8
Bug fix for the MNISTTest.
Craigacp Dec 6, 2019
d1868ea
Refactor to uptake latest tensorflow-core changes.
Craigacp Jan 31, 2020
6d189cc
Added type safety and updates for new api.
Craigacp Jan 31, 2020
53e438a
Small changes, plus a fix for DataTypes to include references to the …
Craigacp Jan 31, 2020
83140b4
Repackaging the optimizers into tensorflow-training, org.tensorflow.t…
Craigacp Feb 7, 2020
b2ac923
Initial commit of gradient descent optimizers.
Craigacp Nov 5, 2019
e7eb2e8
Adding Apache 2.0 license header to all optimizer files.
Craigacp Nov 5, 2019
3d63564
Bug fix for the MNISTTest.
Craigacp Dec 6, 2019
ed71dc5
Refactor to uptake latest tensorflow-core changes.
Craigacp Jan 31, 2020
b054449
Added type safety and updates for new api.
Craigacp Jan 31, 2020
b29be50
Repackaging the optimizers into tensorflow-training, org.tensorflow.t…
Craigacp Feb 7, 2020
6cdb55c
Delete pom.xml
Craigacp Feb 8, 2020
b9d64c5
Googlify with IntelliJ's Google Java Style Guide formatter.
Craigacp Feb 11, 2020
6ae5ace
Bumping the copyright year, and switching to try-with-resources in th…
Craigacp Feb 12, 2020
56429e8
Updating variableWithInit to use @Endpoint.
Craigacp Feb 25, 2020
5d8cb69
Refactorings after code review.
Craigacp Feb 25, 2020
51f5d47
Adding a couple of lines to the gitignore.
Craigacp Feb 25, 2020
66876ed
Adding a bit of documentation, threading the named operations through…
Craigacp Feb 25, 2020
7a2fd25
Adding a guard to prevent variableWithInit being called on an EagerSe…
Craigacp Feb 25, 2020
1b98f52
Update Ops.java
karllessard Mar 2, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ xcuserdata/**
/estimator_api_init_files_list.txt
*.whl

# Patch files
*.orig
*.rej

# Android
.gradle
.idea
Expand Down
11 changes: 11 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<modules>
<module>tensorflow-tools</module>
<module>tensorflow-core</module>
<module>tensorflow-training</module>
</modules>

<properties>
Expand All @@ -39,6 +40,7 @@
<maven.compiler.target>1.8</maven.compiler.target>
<junit.version>4.12</junit.version>
<jmh.version>1.21</jmh.version>
<skipAssembly>true</skipAssembly>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -120,6 +122,15 @@
</executions>
</plugin-->
</plugins>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.2.0</version>
</plugin>
</plugins>
</pluginManagement>
</build>

</project>
Expand Down
9 changes: 9 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,15 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.2.0</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.GuaranteeConst;
import org.tensorflow.op.core.HashTable;
import org.tensorflow.op.core.Helpers;
import org.tensorflow.op.core.HistogramFixedWidth;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.IdentityN;
Expand Down Expand Up @@ -1846,7 +1847,7 @@ public Gradients gradients(Iterable<? extends Operand<?>> y, Iterable<? extends
* Example of usage:
* <pre>{@code
* Gradients gradients = tf.gradients(loss, Arrays.asList(w, b));
* Scalar<TFloat32> alpha = ops.scalar(1.0f);
* Constant<TFloat32> alpha = tf.val(1.0f);
* tf.train.applyGradientDescent(w, alpha, gradients.<Float>dy(0));
* tf.train.applyGradientDescent(b, alpha, gradients.<Float>dy(1));
* }</pre>
Expand Down Expand Up @@ -7332,6 +7333,21 @@ public VarIsInitializedOp varIsInitializedOp(Operand<?> resource) {
return VarIsInitializedOp.create(scope, resource);
}

/**
* Factory method to create a new Variable with its initializer.
* <p>
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op
* does not work in an EagerSession.
*
* @param scope current scope
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
*/
public <T extends TType> Variable<T> variable(Operand<T> init, Variable.Options... options) {
return Helpers.createVariableWithInit(scope, init, options);
}

/**
* Holds state in the form of a tensor that persists across steps.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,6 @@ static DataType<?> fromNativeCode(int nativeCode) {
// to allow user to register custom data types?
private static void register(DataType<?> dataType) {
DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType);
DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewWhile;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -38,6 +40,9 @@
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_WhileParams;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.NoOp;


/**
* A data flow graph representing a TensorFlow computation.
Expand All @@ -49,6 +54,8 @@
*/
public final class Graph implements ExecutionEnvironment, AutoCloseable {

public static final String DEFAULT_INIT_NAME = "init";

/** Create an empty Graph. */
public Graph() {
nativeHandle = allocate();
Expand Down Expand Up @@ -166,6 +173,28 @@ public byte[] toGraphDef() {
}
}

/**
* Adds an initializer to the graph initializer list.
* @param initializer An initializer to add to the list.
*/
public synchronized void addInitializer(Operand<?> initializer) {
initializers.add(initializer);
}

/**
* Returns an op which initializers all the variables.
* @return The initializer operation.
*/
public NoOp variablesInitializer() {
return variablesInitializer(DEFAULT_INIT_NAME);
}

public NoOp variablesInitializer(String name) {
Scope scope = new Scope(this);
scope = scope.withName(name).withControlDependencies(initializers);
return NoOp.create(scope);
}

/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
* {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
Expand Down Expand Up @@ -378,6 +407,8 @@ public Output<?>[] whileLoop(
private TF_Graph nativeHandle;
private int refcount = 0;

private final List<Operand<?>> initializers = new ArrayList<>();

// Related native objects (such as the TF_Operation object backing an Operation instance)
// have a validity tied to that of the Graph. The handles to those native objects are not
// valid after Graph.close() has been invoked.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ public final class Tensor<T extends TType> implements AutoCloseable {
* @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
* system.
*/
@SuppressWarnings("unchecked")
public static <T extends TType> Tensor<T> create(Object obj, DataType<T> dtype) {
if (!objectCompatWithType(obj, dtype)) {
throw new IllegalArgumentException(
Expand All @@ -158,7 +157,7 @@ public static <T extends TType> Tensor<T> create(Object obj, DataType<T> dtype)
}
long[] dimSizes = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, dimSizes);
Tensor<T> t = new Tensor(dtype, Shape.of(dimSizes));
Tensor<T> t = new Tensor<>(dtype, Shape.of(dimSizes));
TF_Tensor nativeHandle;
if (t.dtype != TString.DTYPE) {
long byteSize = elemByteSize(t.dtype) * t.shape.size();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.op.core;

import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.family.TType;

/**
* Container class for core methods which add or perform several operations
* and return one of them.
*/
@Operator
public abstract class Helpers {

/**
* This class contains static factories.
*/
private Helpers() {}

/**
* Factory method to create a new Variable with it's initializer.
* <p>
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op
* does not work in an EagerSession.
* @param scope current scope
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
*/
@Endpoint(name="variable")
public static <T extends TType> Variable<T> createVariableWithInit(Scope scope, Operand<T> init, Variable.Options... options) {
Output<T> initOutput = init.asOutput();
Variable<T> newVar = Variable.create(scope,initOutput.shape(),initOutput.dataType(),options);
Assign<T> assignOp = Assign.create(scope,newVar,init);
ExecutionEnvironment exEnv = scope.env();
if (exEnv instanceof Graph) {
Graph graph = (Graph) exEnv;
graph.addInitializer(assignOp);
} else {
throw new IllegalArgumentException("variable with init is only supported on Graph sessions.");
}

return newVar;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ public static <T> NdArray<T> scalarOfObject(T value) {
*/
@SafeVarargs
public static <T> NdArray<T> vectorOfObjects(T... values) {
if (values == null) {
throw new IllegalArgumentException();
if (values == null || values.length == 0) {
throw new IllegalArgumentException("Null or zero length input supplied to vectorOfObjects.");
}
return wrap(Shape.of(values.length), DataBuffers.from(values, false, false));
}
Expand Down
77 changes: 77 additions & 0 deletions tensorflow-training/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
<!--
Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-java</artifactId>
<version>0.1.0-SNAPSHOT</version>
</parent>
<artifactId>tensorflow-training</artifactId>
<packaging>jar</packaging>

<name>TensorFlow Training Library</name>
<description>
Operations for training Tensorflow models.
</description>

<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.2</version>
<configuration>
<forkCount>1</forkCount>
<reuseForks>false</reuseForks>
<argLine>-Xmx2G -XX:MaxPermSize=256m</argLine>
<skipTests>false</skipTests>
<includes>
<include>**/*Test.java</include>
</includes>
</configuration>
</plugin>
</plugins>
</build>

</project>
Loading