1919package org .apache .beam .runners .core .construction ;
2020
2121import static com .google .common .base .Preconditions .checkArgument ;
22+ import static com .google .common .base .Preconditions .checkState ;
2223import static org .apache .beam .runners .core .construction .PTransformTranslation .PAR_DO_TRANSFORM_URN ;
2324
2425import com .google .auto .service .AutoService ;
2526import com .google .auto .value .AutoValue ;
27+ import com .google .common .annotations .VisibleForTesting ;
2628import com .google .common .base .Optional ;
2729import com .google .common .collect .Iterables ;
2830import com .google .common .collect .Sets ;
4648import org .apache .beam .sdk .common .runner .v1 .RunnerApi .SdkFunctionSpec ;
4749import org .apache .beam .sdk .common .runner .v1 .RunnerApi .SideInput ;
4850import org .apache .beam .sdk .common .runner .v1 .RunnerApi .SideInput .Builder ;
49- import org .apache .beam .sdk .common .runner .v1 .RunnerApi .StateSpec ;
50- import org .apache .beam .sdk .common .runner .v1 .RunnerApi .TimerSpec ;
5151import org .apache .beam .sdk .runners .AppliedPTransform ;
52+ import org .apache .beam .sdk .state .StateSpec ;
53+ import org .apache .beam .sdk .state .StateSpecs ;
54+ import org .apache .beam .sdk .state .TimeDomain ;
55+ import org .apache .beam .sdk .state .TimerSpec ;
56+ import org .apache .beam .sdk .transforms .Combine ;
5257import org .apache .beam .sdk .transforms .DoFn ;
5358import org .apache .beam .sdk .transforms .Materializations ;
5459import org .apache .beam .sdk .transforms .PTransform ;
@@ -107,7 +112,8 @@ public String getUrn(ParDo.MultiOutput<?, ?> transform) {
107112
108113 @ Override
109114 public FunctionSpec translate (
110- AppliedPTransform <?, ?, MultiOutput <?, ?>> transform , SdkComponents components ) {
115+ AppliedPTransform <?, ?, MultiOutput <?, ?>> transform , SdkComponents components )
116+ throws IOException {
111117 ParDoPayload payload = toProto (transform .getTransform (), components );
112118 return RunnerApi .FunctionSpec .newBuilder ()
113119 .setUrn (PAR_DO_TRANSFORM_URN )
@@ -128,8 +134,10 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar {
128134 }
129135 }
130136
131- public static ParDoPayload toProto (ParDo .MultiOutput <?, ?> parDo , SdkComponents components ) {
132- DoFnSignature signature = DoFnSignatures .getSignature (parDo .getFn ().getClass ());
137+ public static ParDoPayload toProto (ParDo .MultiOutput <?, ?> parDo , SdkComponents components )
138+ throws IOException {
139+ DoFn <?, ?> doFn = parDo .getFn ();
140+ DoFnSignature signature = DoFnSignatures .getSignature (doFn .getClass ());
133141 Map <String , StateDeclaration > states = signature .stateDeclarations ();
134142 Map <String , TimerDeclaration > timers = signature .timerDeclarations ();
135143 List <Parameter > parameters = signature .processElement ().extraParameters ();
@@ -146,16 +154,62 @@ public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents
146154 }
147155 }
148156 for (Map .Entry <String , StateDeclaration > state : states .entrySet ()) {
149- StateSpec spec = toProto (state .getValue ());
157+ RunnerApi .StateSpec spec =
158+ toProto (getStateSpecOrCrash (state .getValue (), doFn ), components );
150159 builder .putStateSpecs (state .getKey (), spec );
151160 }
152161 for (Map .Entry <String , TimerDeclaration > timer : timers .entrySet ()) {
153- TimerSpec spec = toProto (timer .getValue ());
162+ RunnerApi .TimerSpec spec =
163+ toProto (getTimerSpecOrCrash (timer .getValue (), doFn ));
154164 builder .putTimerSpecs (timer .getKey (), spec );
155165 }
156166 return builder .build ();
157167 }
158168
169+ private static StateSpec <?> getStateSpecOrCrash (
170+ StateDeclaration stateDeclaration , DoFn <?, ?> target ) {
171+ try {
172+ Object fieldValue = stateDeclaration .field ().get (target );
173+ checkState (fieldValue instanceof StateSpec ,
174+ "Malformed %s class %s: state declaration field %s does not have type %s." ,
175+ DoFn .class .getSimpleName (),
176+ target .getClass ().getName (),
177+ stateDeclaration .field ().getName (),
178+ StateSpec .class );
179+
180+ return (StateSpec <?>) stateDeclaration .field ().get (target );
181+ } catch (IllegalAccessException exc ) {
182+ throw new RuntimeException (
183+ String .format (
184+ "Malformed %s class %s: state declaration field %s is not accessible." ,
185+ DoFn .class .getSimpleName (),
186+ target .getClass ().getName (),
187+ stateDeclaration .field ().getName ()));
188+ }
189+ }
190+
191+ private static TimerSpec getTimerSpecOrCrash (
192+ TimerDeclaration timerDeclaration , DoFn <?, ?> target ) {
193+ try {
194+ Object fieldValue = timerDeclaration .field ().get (target );
195+ checkState (fieldValue instanceof TimerSpec ,
196+ "Malformed %s class %s: timer declaration field %s does not have type %s." ,
197+ DoFn .class .getSimpleName (),
198+ target .getClass ().getName (),
199+ timerDeclaration .field ().getName (),
200+ TimerSpec .class );
201+
202+ return (TimerSpec ) timerDeclaration .field ().get (target );
203+ } catch (IllegalAccessException exc ) {
204+ throw new RuntimeException (
205+ String .format (
206+ "Malformed %s class %s: timer declaration field %s is not accessible." ,
207+ DoFn .class .getSimpleName (),
208+ target .getClass ().getName (),
209+ timerDeclaration .field ().getName ()));
210+ }
211+ }
212+
159213 public static DoFn <?, ?> getDoFn (ParDoPayload payload ) throws InvalidProtocolBufferException {
160214 return doFnAndMainOutputTagFromProto (payload .getDoFn ()).getDoFn ();
161215 }
@@ -179,14 +233,149 @@ public static RunnerApi.PCollection getMainInput(
179233 return components .getPcollectionsOrThrow (ptransform .getInputsOrThrow (mainInputId ));
180234 }
181235
182- // TODO: Implement
183- private static StateSpec toProto (StateDeclaration state ) {
184- throw new UnsupportedOperationException ("Not yet supported" );
236+ @ VisibleForTesting
237+ static RunnerApi .StateSpec toProto (StateSpec <?> stateSpec , final SdkComponents components )
238+ throws IOException {
239+ final RunnerApi .StateSpec .Builder builder = RunnerApi .StateSpec .newBuilder ();
240+
241+ return stateSpec .match (
242+ new StateSpec .Cases <RunnerApi .StateSpec >() {
243+ @ Override
244+ public RunnerApi .StateSpec dispatchValue (Coder <?> valueCoder ) {
245+ return builder
246+ .setValueSpec (
247+ RunnerApi .ValueStateSpec .newBuilder ()
248+ .setCoderId (registerCoderOrThrow (components , valueCoder )))
249+ .build ();
250+ }
251+
252+ @ Override
253+ public RunnerApi .StateSpec dispatchBag (Coder <?> elementCoder ) {
254+ return builder
255+ .setBagSpec (
256+ RunnerApi .BagStateSpec .newBuilder ()
257+ .setElementCoderId (registerCoderOrThrow (components , elementCoder )))
258+ .build ();
259+ }
260+
261+ @ Override
262+ public RunnerApi .StateSpec dispatchCombining (
263+ Combine .CombineFn <?, ?, ?> combineFn , Coder <?> accumCoder ) {
264+ return builder
265+ .setCombiningSpec (
266+ RunnerApi .CombiningStateSpec .newBuilder ()
267+ .setAccumulatorCoderId (registerCoderOrThrow (components , accumCoder ))
268+ .setCombineFn (CombineTranslation .toProto (combineFn )))
269+ .build ();
270+ }
271+
272+ @ Override
273+ public RunnerApi .StateSpec dispatchMap (Coder <?> keyCoder , Coder <?> valueCoder ) {
274+ return builder
275+ .setMapSpec (
276+ RunnerApi .MapStateSpec .newBuilder ()
277+ .setKeyCoderId (registerCoderOrThrow (components , keyCoder ))
278+ .setValueCoderId (registerCoderOrThrow (components , valueCoder )))
279+ .build ();
280+ }
281+
282+ @ Override
283+ public RunnerApi .StateSpec dispatchSet (Coder <?> elementCoder ) {
284+ return builder
285+ .setSetSpec (
286+ RunnerApi .SetStateSpec .newBuilder ()
287+ .setElementCoderId (registerCoderOrThrow (components , elementCoder )))
288+ .build ();
289+ }
290+ });
291+ }
292+
293+ @ VisibleForTesting
294+ static StateSpec <?> fromProto (RunnerApi .StateSpec stateSpec , RunnerApi .Components components )
295+ throws IOException {
296+ switch (stateSpec .getSpecCase ()) {
297+ case VALUE_SPEC :
298+ return StateSpecs .value (
299+ CoderTranslation .fromProto (
300+ components .getCodersMap ().get (stateSpec .getValueSpec ().getCoderId ()), components ));
301+ case BAG_SPEC :
302+ return StateSpecs .bag (
303+ CoderTranslation .fromProto (
304+ components .getCodersMap ().get (stateSpec .getBagSpec ().getElementCoderId ()),
305+ components ));
306+ case COMBINING_SPEC :
307+ FunctionSpec combineFnSpec = stateSpec .getCombiningSpec ().getCombineFn ().getSpec ();
308+
309+ if (!combineFnSpec .getUrn ().equals (CombineTranslation .JAVA_SERIALIZED_COMBINE_FN_URN )) {
310+ throw new UnsupportedOperationException (
311+ String .format (
312+ "Cannot create %s from non-Java %s: %s" ,
313+ StateSpec .class .getSimpleName (),
314+ Combine .CombineFn .class .getSimpleName (),
315+ combineFnSpec .getUrn ()));
316+ }
317+
318+ Combine .CombineFn <?, ?, ?> combineFn =
319+ (Combine .CombineFn <?, ?, ?>)
320+ SerializableUtils .deserializeFromByteArray (
321+ combineFnSpec .getParameter ().unpack (BytesValue .class ).toByteArray (),
322+ Combine .CombineFn .class .getSimpleName ());
323+
324+ // Rawtype coder cast because it is required to be a valid accumulator coder
325+ // for the CombineFn, by construction
326+ return StateSpecs .combining (
327+ (Coder )
328+ CoderTranslation .fromProto (
329+ components
330+ .getCodersMap ()
331+ .get (stateSpec .getCombiningSpec ().getAccumulatorCoderId ()),
332+ components ),
333+ combineFn );
334+
335+ case MAP_SPEC :
336+ return StateSpecs .map (
337+ CoderTranslation .fromProto (
338+ components .getCodersOrThrow (stateSpec .getMapSpec ().getKeyCoderId ()), components ),
339+ CoderTranslation .fromProto (
340+ components .getCodersOrThrow (stateSpec .getMapSpec ().getValueCoderId ()), components ));
341+
342+ case SET_SPEC :
343+ return StateSpecs .set (
344+ CoderTranslation .fromProto (
345+ components .getCodersMap ().get (stateSpec .getSetSpec ().getElementCoderId ()),
346+ components ));
347+
348+ case SPEC_NOT_SET :
349+ default :
350+ throw new IllegalArgumentException (
351+ String .format ("Unknown %s: %s" , RunnerApi .StateSpec .class .getName (), stateSpec ));
352+
353+ }
354+ }
355+
356+ private static String registerCoderOrThrow (SdkComponents components , Coder coder ) {
357+ try {
358+ return components .registerCoder (coder );
359+ } catch (IOException exc ) {
360+ throw new RuntimeException ("Failure to register coder" , exc );
361+ }
185362 }
186363
187- // TODO: Implement
188- private static TimerSpec toProto (TimerDeclaration timer ) {
189- throw new UnsupportedOperationException ("Not yet supported" );
364+ private static RunnerApi .TimerSpec toProto (TimerSpec timer ) {
365+ return RunnerApi .TimerSpec .newBuilder ().setTimeDomain (toProto (timer .getTimeDomain ())).build ();
366+ }
367+
368+ private static RunnerApi .TimeDomain toProto (TimeDomain timeDomain ) {
369+ switch (timeDomain ) {
370+ case EVENT_TIME :
371+ return RunnerApi .TimeDomain .EVENT_TIME ;
372+ case PROCESSING_TIME :
373+ return RunnerApi .TimeDomain .PROCESSING_TIME ;
374+ case SYNCHRONIZED_PROCESSING_TIME :
375+ return RunnerApi .TimeDomain .SYNCHRONIZED_PROCESSING_TIME ;
376+ default :
377+ throw new IllegalArgumentException ("Unknown time domain" );
378+ }
190379 }
191380
192381 @ AutoValue
0 commit comments