1515
1616package org .tensorflow ;
1717
18+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_DeleteOp ;
19+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_DeleteTensorHandle ;
20+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_OpGetInputLength ;
21+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_OpGetOutputLength ;
22+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleDataType ;
23+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleDim ;
24+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleNumDims ;
25+ import static org .tensorflow .internal .c_api .global .tensorflow .TFE_TensorHandleResolve ;
26+
1827import java .util .concurrent .atomic .AtomicReferenceArray ;
28+ import org .bytedeco .javacpp .PointerScope ;
29+ import org .tensorflow .internal .c_api .TFE_Op ;
30+ import org .tensorflow .internal .c_api .TFE_TensorHandle ;
31+ import org .tensorflow .internal .c_api .TF_Status ;
32+ import org .tensorflow .internal .c_api .TF_Tensor ;
1933import org .tensorflow .tools .Shape ;
2034
2135/**
@@ -31,8 +45,8 @@ class EagerOperation extends AbstractOperation {
3145
3246 EagerOperation (
3347 EagerSession session ,
34- long opNativeHandle ,
35- long [] outputNativeHandles ,
48+ TFE_Op opNativeHandle ,
49+ TFE_TensorHandle [] outputNativeHandles ,
3650 String type ,
3751 String name ) {
3852 this .session = session ;
@@ -68,7 +82,7 @@ public int inputListLength(final String name) {
6882 }
6983
7084 @ Override
71- public long getUnsafeNativeHandle (int outputIndex ) {
85+ public TFE_TensorHandle getUnsafeNativeHandle (int outputIndex ) {
7286 return nativeRef .outputHandles [outputIndex ];
7387 }
7488
@@ -80,7 +94,7 @@ public Shape shape(int outputIndex) {
8094 if (tensor != null ) {
8195 return tensor .shape ();
8296 }
83- long outputNativeHandle = getUnsafeNativeHandle (outputIndex );
97+ TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle (outputIndex );
8498 long [] shape = new long [numDims (outputNativeHandle )];
8599 for (int i = 0 ; i < shape .length ; ++i ) {
86100 shape [i ] = dim (outputNativeHandle , i );
@@ -96,7 +110,7 @@ public DataType<?> dtype(int outputIndex) {
96110 if (tensor != null ) {
97111 return tensor .dataType ();
98112 }
99- long outputNativeHandle = getUnsafeNativeHandle (outputIndex );
113+ TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle (outputIndex );
100114 return DataTypes .fromNativeCode (dataType (outputNativeHandle ));
101115 }
102116
@@ -119,7 +133,7 @@ private Tensor<?> resolveTensor(int outputIndex) {
119133 // Take an optimistic approach, where we attempt to resolve the output tensor without locking.
120134 // If another thread has resolved it meanwhile, release our copy and reuse the existing one
121135 // instead.
122- long tensorNativeHandle = resolveTensorHandle (getUnsafeNativeHandle (outputIndex ));
136+ TF_Tensor tensorNativeHandle = resolveTensorHandle (getUnsafeNativeHandle (outputIndex ));
123137 Tensor <?> tensor = Tensor .fromHandle (tensorNativeHandle , session );
124138 if (!outputTensors .compareAndSet (outputIndex , null , tensor )) {
125139 tensor .close ();
@@ -131,43 +145,104 @@ private Tensor<?> resolveTensor(int outputIndex) {
131145 private static class NativeReference extends EagerSession .NativeReference {
132146
133147 NativeReference (
134- EagerSession session , EagerOperation operation , long opHandle , long [] outputHandles ) {
148+ EagerSession session , EagerOperation operation , TFE_Op opHandle , TFE_TensorHandle [] outputHandles ) {
135149 super (session , operation );
136150 this .opHandle = opHandle ;
137151 this .outputHandles = outputHandles ;
138152 }
139153
140154 @ Override
141155 void delete () {
142- if (opHandle != 0L ) {
156+ if (opHandle != null && ! opHandle . isNull () ) {
143157 for (int i = 0 ; i < outputHandles .length ; ++i ) {
144- if (outputHandles [i ] != 0L ) {
158+ if (outputHandles [i ] != null && ! outputHandles [ i ]. isNull () ) {
145159 EagerOperation .deleteTensorHandle (outputHandles [i ]);
146- outputHandles [i ] = 0L ;
160+ outputHandles [i ] = null ;
147161 }
148162 }
149163 EagerOperation .delete (opHandle );
150- opHandle = 0L ;
164+ opHandle = null ;
151165 }
152166 }
153167
154- private long opHandle ;
155- private final long [] outputHandles ;
168+ private TFE_Op opHandle ;
169+ private final TFE_TensorHandle [] outputHandles ;
156170 }
157-
158- private static native void delete (long handle );
159171
160- private static native void deleteTensorHandle (long handle );
172+ private static void requireOp (TFE_Op handle ) {
173+ if (handle == null || handle .isNull ()) {
174+ throw new IllegalStateException ("Eager session has been closed" );
175+ }
176+ }
161177
162- private static native long resolveTensorHandle (long handle );
178+ private static void requireTensorHandle (TFE_TensorHandle handle ) {
179+ if (handle == null || handle .isNull ()) {
180+ throw new IllegalStateException ("EagerSession has been closed" );
181+ }
182+ }
163183
164- private static native int outputListLength (long handle , String name );
184+ private static void delete (TFE_Op handle ) {
185+ if (handle == null || handle .isNull ()) return ;
186+ TFE_DeleteOp (handle );
187+ }
165188
166- private static native int inputListLength (long handle , String name );
189+ private static void deleteTensorHandle (TFE_TensorHandle handle ) {
190+ if (handle == null || handle .isNull ()) return ;
191+ TFE_DeleteTensorHandle (handle );
192+ }
167193
168- private static native int dataType (long handle );
194+ private static TF_Tensor resolveTensorHandle (TFE_TensorHandle handle ) {
195+ requireTensorHandle (handle );
196+ try (PointerScope scope = new PointerScope ()) {
197+ TF_Status status = TF_Status .newStatus ();
198+ TF_Tensor tensor = TFE_TensorHandleResolve (handle , status );
199+ status .throwExceptionIfNotOK ();
200+ return tensor ;
201+ }
202+ }
169203
170- private static native int numDims (long handle );
204+ private static int outputListLength (TFE_Op handle , String name ) {
205+ requireOp (handle );
206+ try (PointerScope scope = new PointerScope ()) {
207+ TF_Status status = TF_Status .newStatus ();
208+ int length = TFE_OpGetOutputLength (handle , name , status );
209+ status .throwExceptionIfNotOK ();
210+ return length ;
211+ }
212+ }
171213
172- private static native long dim (long handle , int index );
173- }
214+ private static int inputListLength (TFE_Op handle , String name ) {
215+ requireOp (handle );
216+ try (PointerScope scope = new PointerScope ()) {
217+ TF_Status status = TF_Status .newStatus ();
218+ int length = TFE_OpGetInputLength (handle , name , status );
219+ status .throwExceptionIfNotOK ();
220+ return length ;
221+ }
222+ }
223+
224+ private static int dataType (TFE_TensorHandle handle ) {
225+ requireTensorHandle (handle );
226+ return TFE_TensorHandleDataType (handle );
227+ }
228+
229+ private static int numDims (TFE_TensorHandle handle ) {
230+ requireTensorHandle (handle );
231+ try (PointerScope scope = new PointerScope ()) {
232+ TF_Status status = TF_Status .newStatus ();
233+ int numDims = TFE_TensorHandleNumDims (handle , status );
234+ status .throwExceptionIfNotOK ();
235+ return numDims ;
236+ }
237+ }
238+
239+ private static long dim (TFE_TensorHandle handle , int index ) {
240+ requireTensorHandle (handle );
241+ try (PointerScope scope = new PointerScope ()) {
242+ TF_Status status = TF_Status .newStatus ();
243+ long dim = TFE_TensorHandleDim (handle , index , status );
244+ status .throwExceptionIfNotOK ();
245+ return dim ;
246+ }
247+ }
248+ }
0 commit comments