Skip to content

Commit

Permalink
[FLINK-37208][Runtime] Properly notify a new key is selected for asyn…
Browse files Browse the repository at this point in the history
…c state operators (#26068)
  • Loading branch information
Zakelly authored Jan 24, 2025
1 parent 19e868f commit 6a5fab8
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.flink.datastream.impl.common.OutputCollector;
import org.apache.flink.datastream.impl.common.TimestampCollector;
import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -83,10 +82,8 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
}

@Override
@SuppressWarnings({"rawtypes"})
public void setKeyContextElement1(StreamRecord record) throws Exception {
super.setKeyContextElement1(record);
keySet.add(getCurrentKey());
public void newKeySelected(Object newKey) {
keySet.add(newKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.flink.datastream.impl.common.OutputCollector;
import org.apache.flink.datastream.impl.common.TimestampCollector;
import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -89,17 +88,8 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
}

@Override
@SuppressWarnings({"rawtypes"})
public void setKeyContextElement1(StreamRecord record) throws Exception {
super.setKeyContextElement1(record);
keySet.add(getCurrentKey());
}

@Override
@SuppressWarnings({"rawtypes"})
public void setKeyContextElement2(StreamRecord record) throws Exception {
super.setKeyContextElement2(record);
keySet.add(getCurrentKey());
public void newKeySelected(Object newKey) {
keySet.add(newKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.flink.datastream.impl.common.OutputCollector;
import org.apache.flink.datastream.impl.common.TimestampCollector;
import org.apache.flink.datastream.impl.context.DefaultTwoOutputNonPartitionedContext;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

Expand Down Expand Up @@ -105,10 +104,8 @@ protected TwoOutputNonPartitionedContext<OUT_MAIN, OUT_SIDE> getNonPartitionedCo
}

@Override
@SuppressWarnings({"rawtypes"})
public void setKeyContextElement1(StreamRecord record) throws Exception {
super.setKeyContextElement1(record);
keySet.add(getCurrentKey());
public void newKeySelected(Object newKey) {
keySet.add(newKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -117,11 +116,8 @@ protected NonPartitionedContext<OUT> getNonPartitionedContext() {
}

@Override
@SuppressWarnings({"rawtypes"})
// Only element from input1 should be considered as the other side is broadcast input.
public void setKeyContextElement1(StreamRecord record) throws Exception {
super.setKeyContextElement1(record);
keySet.add(getCurrentKey());
public void newKeySelected(Object newKey) {
keySet.add(newKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.apache.flink.datastream.impl.operators.MockSumAggregateProcessFunction;
import org.apache.flink.streaming.api.operators.collect.utils.MockOperatorStateStore;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedOneInputStreamOperatorTestHarness;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -98,8 +98,8 @@ void testListState() throws Exception {
new KeyedProcessOperator<>(
function, (KeySelector<Integer, Integer>) value -> value);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand All @@ -120,8 +120,8 @@ void testAggState() throws Exception {
KeyedProcessOperator<Integer, Integer, Integer> processOperator =
new KeyedProcessOperator<>(function);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand All @@ -148,8 +148,8 @@ void testValueState() throws Exception {
KeyedProcessOperator<Integer, Integer, Integer> processOperator =
new KeyedProcessOperator<>(function);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand All @@ -175,8 +175,8 @@ void testMapState() throws Exception {
KeyedProcessOperator<Integer, Integer, Integer> processOperator =
new KeyedProcessOperator<>(function);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand All @@ -203,8 +203,8 @@ void testReducingState() throws Exception {
KeyedProcessOperator<Integer, Integer, Integer> processOperator =
new KeyedProcessOperator<>(function);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand All @@ -231,8 +231,8 @@ void testBroadcastMapState() throws Exception {
KeyedProcessOperator<Integer, Integer, Integer> processOperator =
new KeyedProcessOperator<>(function);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down Expand Up @@ -260,8 +260,8 @@ void testBroadcastListState() throws Exception {
new KeyedProcessOperator<>(
function, (KeySelector<Integer, Integer>) value -> value);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.flink.datastream.api.context.PartitionedContext;
import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedOneInputStreamOperatorTestHarness;

import org.junit.jupiter.api.Test;

Expand All @@ -51,8 +51,8 @@ public void processRecord(
}
});

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down Expand Up @@ -98,8 +98,8 @@ public void endInput(NonPartitionedContext<Integer> ctx) {
}
});

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down Expand Up @@ -133,8 +133,8 @@ public void processRecord(
// -1 is an invalid key in this suite.
(ignore) -> -1);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.flink.datastream.api.context.PartitionedContext;
import org.apache.flink.datastream.api.function.TwoInputBroadcastStreamProcessFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedTwoInputStreamOperatorTestHarness;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -61,8 +61,8 @@ public void processRecordFromBroadcastInput(
}
});

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) (data) -> (long) (data + 1),
(KeySelector<Long, Long>) value -> value + 1,
Expand Down Expand Up @@ -130,11 +130,11 @@ public void endBroadcastInput(NonPartitionedContext<Long> ctx) {
}
});

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) Long::valueOf,
(KeySelector<Long, Long>) value -> value,
null,
Types.LONG)) {
testHarness.open();
testHarness.processElement1(new StreamRecord<>(1)); // key is 1L
Expand Down Expand Up @@ -175,8 +175,8 @@ public void processRecordFromBroadcastInput(
// -1 is an invalid key in this suite.
(out) -> -1L);

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) Long::valueOf,
(KeySelector<Long, Long>) value -> value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.flink.datastream.api.context.PartitionedContext;
import org.apache.flink.datastream.api.function.TwoInputNonBroadcastStreamProcessFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedTwoInputStreamOperatorTestHarness;

import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -59,8 +59,8 @@ public void processRecordFromSecondInput(
}
});

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) (data) -> (long) (data + 1),
(KeySelector<Long, Long>) value -> value + 1,
Expand Down Expand Up @@ -134,8 +134,8 @@ public void endSecondInput(NonPartitionedContext<Long> ctx) {
}
});

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) Long::valueOf,
(KeySelector<Long, Long>) value -> value,
Expand Down Expand Up @@ -183,15 +183,23 @@ public void processRecordFromSecondInput(
// -1 is an invalid key in this suite.
(out) -> -1L);

try (KeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) Long::valueOf,
(KeySelector<Long, Long>) value -> value,
Types.LONG)) {
testHarness.open();
assertThatThrownBy(() -> testHarness.processElement1(new StreamRecord<>(1)))
.isInstanceOf(IllegalStateException.class);
}
try (AsyncKeyedTwoInputStreamOperatorTestHarness<Long, Integer, Long, Long> testHarness =
AsyncKeyedTwoInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Long>) Long::valueOf,
(KeySelector<Long, Long>) value -> value,
Types.LONG)) {
testHarness.open();
assertThatThrownBy(() -> testHarness.processElement2(new StreamRecord<>(1L)))
.isInstanceOf(IllegalStateException.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.flink.datastream.api.context.TwoOutputPartitionedContext;
import org.apache.flink.datastream.api.function.TwoOutputStreamProcessFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.util.OutputTag;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -59,8 +59,8 @@ public void processRecord(
},
sideOutputTag);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down Expand Up @@ -116,8 +116,8 @@ public void endInput(
},
sideOutputTag);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
Expand Down Expand Up @@ -161,14 +161,21 @@ public void processRecord(
// -1 is an invalid key in this suite.
(KeySelector<Long, Integer>) value -> -1);

try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
testHarness.open();
assertThatThrownBy(() -> testHarness.processElement(new StreamRecord<>(1)))
.isInstanceOf(IllegalStateException.class);
}
try (AsyncKeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
AsyncKeyedOneInputStreamOperatorTestHarness.create(
processOperator,
(KeySelector<Integer, Integer>) value -> value,
Types.INT)) {
testHarness.open();
emitToFirstOutput.set(false);
assertThatThrownBy(() -> testHarness.processElement(new StreamRecord<>(1)))
.isInstanceOf(IllegalStateException.class);
Expand Down
Loading

0 comments on commit 6a5fab8

Please sign in to comment.