Last active
March 12, 2025 16:27
-
-
Save mrsimpson/7dc9c3f2a15eddb65dd3516dca3f28c6 to your computer and use it in GitHub Desktop.
Controllable Source Function for Flink integration testing of a complete pipeline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package de.db.lightgate.streaming.functions; | |
| import java.util.ArrayList; | |
| import java.util.List; | |
| import java.util.concurrent.CompletableFuture; | |
| import java.util.concurrent.ConcurrentHashMap; | |
| import java.util.concurrent.ConcurrentMap; | |
| import java.util.concurrent.CountDownLatch; | |
| import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; | |
| import org.apache.flink.streaming.api.watermark.Watermark; | |
| /** | |
| * A Source function that be explicitly started and awaited for. This allows for using it in a <a | |
| * href="https://ci.apache.org/projects/flink/flink-docs-stable/dev/stream/testing.html#integration-testing">Flink | |
| * integration test </a> | |
| * | |
| * @param <T> Type of the elements to be collected | |
| * @see <a href="https://stackoverflow.com/a/54924081/6141374">Original source on Stack Overflow by | |
| * Till Rohrmann which also depicts how to write the actual test</a>, Licence CC BY-SA 4.0 | |
| */ | |
| public class ControllableSourceFunction<T> extends RichParallelSourceFunction<T> { | |
| private static final ConcurrentMap<String, CountDownLatch> startLatches = | |
| new ConcurrentHashMap<>(); | |
| private static final ConcurrentMap<String, CompletableFuture<Void>> finishedFutures = | |
| new ConcurrentHashMap<>(); | |
| private final String name; | |
| private final ArrayList<T> streamElements; | |
| private final int keepAliveMs; | |
| private volatile boolean running; | |
| public ControllableSourceFunction(String name, List<T> streamElements, int keepAliveMs) { | |
| this.name = name; | |
| this.streamElements = new ArrayList<>(streamElements); | |
| this.keepAliveMs = keepAliveMs; | |
| } | |
| public ControllableSourceFunction(String name, List<T> streamElements) { | |
| this(name, streamElements, 0); | |
| } | |
| public static void startExecution(ControllableSourceFunction source, int index) { | |
| final CountDownLatch startLatch = | |
| startLatches.computeIfAbsent(source.getId(index), ignored -> new CountDownLatch(1)); | |
| startLatch.countDown(); | |
| } | |
| public static CompletableFuture<Void> getFinishedFuture( | |
| ControllableSourceFunction source, int index) { | |
| return finishedFutures.computeIfAbsent( | |
| source.getId(index), ignored -> new CompletableFuture<>()); | |
| } | |
| @Override | |
| public void run(SourceContext<T> ctx) throws Exception { | |
| final int index = getRuntimeContext().getIndexOfThisSubtask(); | |
| final CountDownLatch startLatch = | |
| startLatches.computeIfAbsent(getId(index), ignored -> new CountDownLatch(1)); | |
| final CompletableFuture<Void> finishedFuture = | |
| finishedFutures.computeIfAbsent(getId(index), ignored -> new CompletableFuture<>()); | |
| this.running = true; | |
| startLatch.await(); | |
| int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask(); | |
| for (int i = indexOfThisSubtask; | |
| i < this.streamElements.size(); | |
| i += getRuntimeContext().getNumberOfParallelSubtasks()) { | |
| T streamElement = this.streamElements.get(i); | |
| if (!running) break; | |
| synchronized (ctx.getCheckpointLock()) { | |
| collect(ctx, streamElement); | |
| } | |
| } | |
| ctx.emitWatermark(new Watermark(Long.MAX_VALUE)); | |
| Thread.sleep(this.keepAliveMs); | |
| finishedFuture.complete(null); | |
| } | |
| /** | |
| * Override this method if either emitting with timestamp is necessary or if you run this as a | |
| * real parallel function: Probably, an identifier which is part of the key needs to be unique | |
| * across tasks . Else, duplicates would be created. Use `getIndexOfThisSubtask()` to generate | |
| * parallelized data | |
| */ | |
| protected void collect(SourceContext<T> ctx, T streamElement) { | |
| ctx.collect(streamElement); | |
| } | |
| @Override | |
| public void cancel() { | |
| running = false; | |
| } | |
| private String getId(int index) { | |
| return name + '_' + index; | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static org.junit.Assert.assertEquals;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobmaster.JobResult;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.concurrent.FutureUtils;
import org.junit.ClassRule;
import org.junit.Test;
| /** | |
| * A skeleton for an integration test which instantiates the whole flink application | |
| * injecting a spy for a sink. | |
| */ | |
| public class SampleIntegrationTest { | |
| private static final int NUM_TMS = 1; | |
| private static final int NUM_SLOTS = 1; | |
| private static final int PARALLELISM = NUM_SLOTS * NUM_TMS; | |
| /** | |
| * A Dummy pojo for the sake of syntax | |
| */ | |
| static class MyPojo {} | |
| static final class SampleInputSource extends ControllableSourceFunction<MyPojo> { | |
| SampleInputSource(List<MyPojo> streamElements) { | |
| super("sample-input-source", streamElements); | |
| } | |
| @Override | |
| public void cancel() {} | |
| } | |
| /** | |
| * A simple sink which can be spied on | |
| * Caution: the list holding the collected records needs to be static so that the elements collected when running the pipeline | |
| * can be asserted in the test after the pipeline has run. | |
| * Keep this class inside each test so that the is not shared across tests. | |
| */ | |
| protected static class CollectSink<T> implements SinkFunction<T> { | |
| public static final List<T> values = Collections.synchronizedList(new ArrayList<>()); | |
| @Override | |
| public void invoke(T value, SinkFunction.Context context) throws Exception { | |
| values.add(value); | |
| } | |
| } | |
| @RegisterExtension | |
| public static final MiniClusterExtension MINI_CLUSTER_EXTENSION = | |
| new MiniClusterExtension( | |
| new MiniClusterResourceConfiguration.Builder() | |
| .setNumberSlotsPerTaskManager(NUM_SLOTS) | |
| .setNumberTaskManagers(NUM_TMS) | |
| .build()); | |
| protected void runPipeline(CollectSink mockSink) throws Exception { | |
| final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); | |
| SampleInputSource sampleInputSource1 = new SampleInputSource(List.of(new MyPojo())); | |
| SampleInputSource sampleInputSource2 = new SampleInputSource(List.of(new MyPojo())); | |
| env.setParallelism(PARALLELISM); | |
| MyFlinkApp app = | |
| new MyFlinkApp( | |
| env, | |
| env.addSource(sampleInputSource1), | |
| env.addSource(sampleInputSource2), | |
| mockSink); | |
| app.run(); //submits the job definition creating the graph | |
| final JobGraph jobGraph = env.getStreamGraph().getJobGraph(); | |
| MINI_CLUSTER_WITH_CLIENT_RESOURCE.getMiniCluster().submitJob(jobGraph).get(); | |
| final CompletableFuture<JobResult> jobResultFuture = | |
| MINI_CLUSTER_WITH_CLIENT_RESOURCE.getMiniCluster().requestJobResult(jobGraph.getJobID()); | |
| final ArrayList<CompletableFuture<Void>> finishedFutures = new ArrayList<>(PARALLELISM); | |
| // trigger the first source and wait for it to be emitted completely | |
| ControllableSourceFunction.startExecution(sampleInputSource1, 0); | |
| finishedFutures.add(ControllableSourceFunction.getFinishedFuture(sampleInputSource1, 0)); | |
| FutureUtils.waitForAll(finishedFutures).join(); | |
| // and the second source afterwards. For illustration purposes, this is parallelized here | |
| for (int i = 0; i < PARALLELISM; i++) { | |
| ControllableSourceFunction.startExecution(sampleInputSource2, i); | |
| finishedFutures.add(ControllableSourceFunction.getFinishedFuture(sampleInputSource2, i)); | |
| } | |
| jobResultFuture.join(); | |
| } | |
| @Test | |
| public void sampleTest() throws Exception { | |
| CollectSink mockSink = new CollectSink(); | |
| CollectSink.values.clear(); // it's a static buffer, so we need to clear it actively (best @BeforeEach) | |
| runPipeline(mockSink); | |
| assertThat(CollectSink.values).hasSize(1); // some sample assertion | |
| } | |
| } |
Author
@kkrugler thanks so much for looking at my code and suggesting improvements 👍 .
It took a while™️, but today I needed the code again, so I:
- added that each task only emits the n-th element
- added an option to keep the source alive after emitting the last element. This was necessary when using a real sink (in my case MQTT) in order to emit all elements (and not kill the thread prior to completion).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some comments...
[… text lost in extraction …] `getIndexOfThisSubtask()` element, otherwise you'll have duplicates when parallelism > 1.
[… text lost in extraction …] the `running` flag is marked as volatile. I usually like setting this to true inside of the `run()` method.