Skip to content

Commit c687887

Browse files
committed
This closes #3233: [BEAM-115] Runner API Translations for StateSpec and TimerSpec
Implement TimerSpec and StateSpec translation Make Java serialized CombineFn URN public Add case dispatch to StateSpec Flesh out TimerSpec and StateSpec in Runner API Allow translation to throw IOException Mark CombineFnWithContext StateSpecs internal
2 parents 6bb204f + 39220db commit c687887

7 files changed

Lines changed: 478 additions & 108 deletions

File tree

runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
* RunnerApi.CombinePayload} protos.
5050
*/
5151
public class CombineTranslation {
52-
private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";
52+
public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";
5353

5454
public static CombinePayload toProto(
5555
AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents)
@@ -86,7 +86,7 @@ private static <K, InputT, AccumT> Coder<AccumT> extractAccumulatorCoder(
8686
.getAccumulatorCoder();
8787
}
8888

89-
private static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) {
89+
public static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) {
9090
return SdkFunctionSpec.newBuilder()
9191
// TODO: Set Java SDK Environment URN
9292
.setSpec(

runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ public static String urnForTransform(PTransform<?, ?> transform) {
138138
*/
139139
public interface TransformPayloadTranslator<T extends PTransform<?, ?>> {
140140
String getUrn(T transform);
141-
FunctionSpec translate(AppliedPTransform<?, ?, T> application, SdkComponents components);
141+
FunctionSpec translate(AppliedPTransform<?, ?, T> application, SdkComponents components)
142+
throws IOException;
142143
}
143144

144145
/**

runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java

Lines changed: 202 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
package org.apache.beam.runners.core.construction;
2020

2121
import static com.google.common.base.Preconditions.checkArgument;
22+
import static com.google.common.base.Preconditions.checkState;
2223
import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
2324

2425
import com.google.auto.service.AutoService;
2526
import com.google.auto.value.AutoValue;
27+
import com.google.common.annotations.VisibleForTesting;
2628
import com.google.common.base.Optional;
2729
import com.google.common.collect.Iterables;
2830
import com.google.common.collect.Sets;
@@ -46,9 +48,12 @@
4648
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
4749
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
4850
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder;
49-
import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec;
50-
import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec;
5151
import org.apache.beam.sdk.runners.AppliedPTransform;
52+
import org.apache.beam.sdk.state.StateSpec;
53+
import org.apache.beam.sdk.state.StateSpecs;
54+
import org.apache.beam.sdk.state.TimeDomain;
55+
import org.apache.beam.sdk.state.TimerSpec;
56+
import org.apache.beam.sdk.transforms.Combine;
5257
import org.apache.beam.sdk.transforms.DoFn;
5358
import org.apache.beam.sdk.transforms.Materializations;
5459
import org.apache.beam.sdk.transforms.PTransform;
@@ -107,7 +112,8 @@ public String getUrn(ParDo.MultiOutput<?, ?> transform) {
107112

108113
@Override
109114
public FunctionSpec translate(
110-
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components) {
115+
AppliedPTransform<?, ?, MultiOutput<?, ?>> transform, SdkComponents components)
116+
throws IOException {
111117
ParDoPayload payload = toProto(transform.getTransform(), components);
112118
return RunnerApi.FunctionSpec.newBuilder()
113119
.setUrn(PAR_DO_TRANSFORM_URN)
@@ -128,8 +134,10 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar {
128134
}
129135
}
130136

131-
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) {
132-
DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass());
137+
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components)
138+
throws IOException {
139+
DoFn<?, ?> doFn = parDo.getFn();
140+
DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
133141
Map<String, StateDeclaration> states = signature.stateDeclarations();
134142
Map<String, TimerDeclaration> timers = signature.timerDeclarations();
135143
List<Parameter> parameters = signature.processElement().extraParameters();
@@ -146,16 +154,62 @@ public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents
146154
}
147155
}
148156
for (Map.Entry<String, StateDeclaration> state : states.entrySet()) {
149-
StateSpec spec = toProto(state.getValue());
157+
RunnerApi.StateSpec spec =
158+
toProto(getStateSpecOrCrash(state.getValue(), doFn), components);
150159
builder.putStateSpecs(state.getKey(), spec);
151160
}
152161
for (Map.Entry<String, TimerDeclaration> timer : timers.entrySet()) {
153-
TimerSpec spec = toProto(timer.getValue());
162+
RunnerApi.TimerSpec spec =
163+
toProto(getTimerSpecOrCrash(timer.getValue(), doFn));
154164
builder.putTimerSpecs(timer.getKey(), spec);
155165
}
156166
return builder.build();
157167
}
158168

169+
private static StateSpec<?> getStateSpecOrCrash(
170+
StateDeclaration stateDeclaration, DoFn<?, ?> target) {
171+
try {
172+
Object fieldValue = stateDeclaration.field().get(target);
173+
checkState(fieldValue instanceof StateSpec,
174+
"Malformed %s class %s: state declaration field %s does not have type %s.",
175+
DoFn.class.getSimpleName(),
176+
target.getClass().getName(),
177+
stateDeclaration.field().getName(),
178+
StateSpec.class);
179+
180+
return (StateSpec<?>) stateDeclaration.field().get(target);
181+
} catch (IllegalAccessException exc) {
182+
throw new RuntimeException(
183+
String.format(
184+
"Malformed %s class %s: state declaration field %s is not accessible.",
185+
DoFn.class.getSimpleName(),
186+
target.getClass().getName(),
187+
stateDeclaration.field().getName()));
188+
}
189+
}
190+
191+
private static TimerSpec getTimerSpecOrCrash(
192+
TimerDeclaration timerDeclaration, DoFn<?, ?> target) {
193+
try {
194+
Object fieldValue = timerDeclaration.field().get(target);
195+
checkState(fieldValue instanceof TimerSpec,
196+
"Malformed %s class %s: timer declaration field %s does not have type %s.",
197+
DoFn.class.getSimpleName(),
198+
target.getClass().getName(),
199+
timerDeclaration.field().getName(),
200+
TimerSpec.class);
201+
202+
return (TimerSpec) timerDeclaration.field().get(target);
203+
} catch (IllegalAccessException exc) {
204+
throw new RuntimeException(
205+
String.format(
206+
"Malformed %s class %s: timer declaration field %s is not accessible.",
207+
DoFn.class.getSimpleName(),
208+
target.getClass().getName(),
209+
timerDeclaration.field().getName()));
210+
}
211+
}
212+
159213
public static DoFn<?, ?> getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException {
160214
return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn();
161215
}
@@ -179,14 +233,149 @@ public static RunnerApi.PCollection getMainInput(
179233
return components.getPcollectionsOrThrow(ptransform.getInputsOrThrow(mainInputId));
180234
}
181235

182-
// TODO: Implement
183-
private static StateSpec toProto(StateDeclaration state) {
184-
throw new UnsupportedOperationException("Not yet supported");
236+
@VisibleForTesting
237+
static RunnerApi.StateSpec toProto(StateSpec<?> stateSpec, final SdkComponents components)
238+
throws IOException {
239+
final RunnerApi.StateSpec.Builder builder = RunnerApi.StateSpec.newBuilder();
240+
241+
return stateSpec.match(
242+
new StateSpec.Cases<RunnerApi.StateSpec>() {
243+
@Override
244+
public RunnerApi.StateSpec dispatchValue(Coder<?> valueCoder) {
245+
return builder
246+
.setValueSpec(
247+
RunnerApi.ValueStateSpec.newBuilder()
248+
.setCoderId(registerCoderOrThrow(components, valueCoder)))
249+
.build();
250+
}
251+
252+
@Override
253+
public RunnerApi.StateSpec dispatchBag(Coder<?> elementCoder) {
254+
return builder
255+
.setBagSpec(
256+
RunnerApi.BagStateSpec.newBuilder()
257+
.setElementCoderId(registerCoderOrThrow(components, elementCoder)))
258+
.build();
259+
}
260+
261+
@Override
262+
public RunnerApi.StateSpec dispatchCombining(
263+
Combine.CombineFn<?, ?, ?> combineFn, Coder<?> accumCoder) {
264+
return builder
265+
.setCombiningSpec(
266+
RunnerApi.CombiningStateSpec.newBuilder()
267+
.setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder))
268+
.setCombineFn(CombineTranslation.toProto(combineFn)))
269+
.build();
270+
}
271+
272+
@Override
273+
public RunnerApi.StateSpec dispatchMap(Coder<?> keyCoder, Coder<?> valueCoder) {
274+
return builder
275+
.setMapSpec(
276+
RunnerApi.MapStateSpec.newBuilder()
277+
.setKeyCoderId(registerCoderOrThrow(components, keyCoder))
278+
.setValueCoderId(registerCoderOrThrow(components, valueCoder)))
279+
.build();
280+
}
281+
282+
@Override
283+
public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
284+
return builder
285+
.setSetSpec(
286+
RunnerApi.SetStateSpec.newBuilder()
287+
.setElementCoderId(registerCoderOrThrow(components, elementCoder)))
288+
.build();
289+
}
290+
});
291+
}
292+
293+
@VisibleForTesting
294+
static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RunnerApi.Components components)
295+
throws IOException {
296+
switch (stateSpec.getSpecCase()) {
297+
case VALUE_SPEC:
298+
return StateSpecs.value(
299+
CoderTranslation.fromProto(
300+
components.getCodersMap().get(stateSpec.getValueSpec().getCoderId()), components));
301+
case BAG_SPEC:
302+
return StateSpecs.bag(
303+
CoderTranslation.fromProto(
304+
components.getCodersMap().get(stateSpec.getBagSpec().getElementCoderId()),
305+
components));
306+
case COMBINING_SPEC:
307+
FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec();
308+
309+
if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) {
310+
throw new UnsupportedOperationException(
311+
String.format(
312+
"Cannot create %s from non-Java %s: %s",
313+
StateSpec.class.getSimpleName(),
314+
Combine.CombineFn.class.getSimpleName(),
315+
combineFnSpec.getUrn()));
316+
}
317+
318+
Combine.CombineFn<?, ?, ?> combineFn =
319+
(Combine.CombineFn<?, ?, ?>)
320+
SerializableUtils.deserializeFromByteArray(
321+
combineFnSpec.getParameter().unpack(BytesValue.class).toByteArray(),
322+
Combine.CombineFn.class.getSimpleName());
323+
324+
// Rawtype coder cast because it is required to be a valid accumulator coder
325+
// for the CombineFn, by construction
326+
return StateSpecs.combining(
327+
(Coder)
328+
CoderTranslation.fromProto(
329+
components
330+
.getCodersMap()
331+
.get(stateSpec.getCombiningSpec().getAccumulatorCoderId()),
332+
components),
333+
combineFn);
334+
335+
case MAP_SPEC:
336+
return StateSpecs.map(
337+
CoderTranslation.fromProto(
338+
components.getCodersOrThrow(stateSpec.getMapSpec().getKeyCoderId()), components),
339+
CoderTranslation.fromProto(
340+
components.getCodersOrThrow(stateSpec.getMapSpec().getValueCoderId()), components));
341+
342+
case SET_SPEC:
343+
return StateSpecs.set(
344+
CoderTranslation.fromProto(
345+
components.getCodersMap().get(stateSpec.getSetSpec().getElementCoderId()),
346+
components));
347+
348+
case SPEC_NOT_SET:
349+
default:
350+
throw new IllegalArgumentException(
351+
String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
352+
353+
}
354+
}
355+
356+
private static String registerCoderOrThrow(SdkComponents components, Coder coder) {
357+
try {
358+
return components.registerCoder(coder);
359+
} catch (IOException exc) {
360+
throw new RuntimeException("Failure to register coder", exc);
361+
}
185362
}
186363

187-
// TODO: Implement
188-
private static TimerSpec toProto(TimerDeclaration timer) {
189-
throw new UnsupportedOperationException("Not yet supported");
364+
private static RunnerApi.TimerSpec toProto(TimerSpec timer) {
365+
return RunnerApi.TimerSpec.newBuilder().setTimeDomain(toProto(timer.getTimeDomain())).build();
366+
}
367+
368+
private static RunnerApi.TimeDomain toProto(TimeDomain timeDomain) {
369+
switch(timeDomain) {
370+
case EVENT_TIME:
371+
return RunnerApi.TimeDomain.EVENT_TIME;
372+
case PROCESSING_TIME:
373+
return RunnerApi.TimeDomain.PROCESSING_TIME;
374+
case SYNCHRONIZED_PROCESSING_TIME:
375+
return RunnerApi.TimeDomain.SYNCHRONIZED_PROCESSING_TIME;
376+
default:
377+
throw new IllegalArgumentException("Unknown time domain");
378+
}
190379
}
191380

192381
@AutoValue

0 commit comments

Comments
 (0)