diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ChunkProcessor.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ChunkProcessor.java index bde053a5d..3f09d6e06 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ChunkProcessor.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ChunkProcessor.java @@ -19,16 +19,6 @@ import java.util.List; -public class ChunkProcessor { - - protected Router router; - - public ChunkProcessor(Router router) { - this.router = router; - } - - public void process(ConnectionManager connectionManager, List chunks) { - router.route(connectionManager.connections(), chunks); - } - +public interface ChunkProcessor { + void process(ConnectionManager connectionManager, List chunks); } diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionGroup.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionGroup.java index 74d2c79d9..b65776305 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionGroup.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionGroup.java @@ -23,11 +23,14 @@ import io.mantisrx.common.metrics.Metrics; import io.mantisrx.common.metrics.spectator.GaugeCallback; import io.mantisrx.common.metrics.spectator.MetricGroupId; + import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import rx.functions.Func0; @@ -43,10 +46,12 @@ public class ConnectionGroup { private Counter successfulWrites; private Counter numSlotSwaps; private Counter failedWrites; + private final Optional> routerO; - public ConnectionGroup(String groupId) { + public ConnectionGroup(String groupId, Optional> routerO) { this.groupId = groupId; this.connections = new HashMap<>(); + this.routerO = routerO; final String grpId = Optional.ofNullable(groupId).orElse("none"); final BasicTag groupIdTag = new BasicTag(MantisMetricStringConstants.GROUP_ID_TAG, grpId); @@ -93,6 +98,7 @@ public synchronized void removeConnection(AsyncConnection connection) { + " a new connection has already been swapped in the place of the old connection"); } + this.routerO.ifPresent(router -> router.removeConnection(connection)); } public synchronized void addConnection(AsyncConnection connection) { @@ -107,6 +113,7 @@ public synchronized void addConnection(AsyncConnection connection) { previousConnection.close(); numSlotSwaps.increment(); } + this.routerO.ifPresent(router -> router.addConnection(connection)); } public synchronized boolean isEmpty() { @@ -132,4 +139,11 @@ public String toString() { return "ConnectionGroup [groupId=" + groupId + ", connections=" + connections + "]"; } + + public void route(List chunks, Router fallbackRouter) { + this.routerO.ifPresentOrElse( + router -> router.route(chunks), + () -> fallbackRouter.route(this.getConnections(), chunks) + ); + } } diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionManager.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionManager.java index 0d14fe320..f1158bb7a 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionManager.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ConnectionManager.java @@ -20,9 +20,11 @@ import io.mantisrx.common.metrics.MetricsRegistry; import io.mantisrx.common.metrics.spectator.GaugeCallback; import io.mantisrx.common.metrics.spectator.MetricGroupId; + import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.Optional; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -32,6 +34,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import rx.functions.Action0; +import rx.functions.Func1; public class ConnectionManager { @@ -46,9 +49,13 @@ public class ConnectionManager { private Action0 doOnZeroConnections; private Lock connectionState = new ReentrantLock(); private AtomicBoolean subscribed = new AtomicBoolean(); + private final Func1>> routerFactory; public ConnectionManager(MetricsRegistry metricsRegistry, - Action0 doOnFirstConnection, Action0 doOnZeroConnections) { + Action0 doOnFirstConnection, + Action0 doOnZeroConnections, + Func1>> routerFactory) { + this.routerFactory = routerFactory; this.doOnFirstConnection = doOnFirstConnection; this.doOnZeroConnections = doOnZeroConnections; this.metricsRegistry = metricsRegistry; @@ -119,11 +126,13 @@ protected void add(AsyncConnection connection) { String groupId = connection.getGroupId(); ConnectionGroup current = managedConnections.get(groupId); if (current == null) { - ConnectionGroup newGroup = new ConnectionGroup(groupId); + Optional> groupRouter = routerFactory.call(groupId); + ConnectionGroup newGroup = new ConnectionGroup(groupId, groupRouter); current = managedConnections.putIfAbsent(groupId, newGroup); if (current == null) { current = newGroup; metricsRegistry.registerAndGet(current.getMetrics()); + groupRouter.ifPresent(router -> metricsRegistry.registerAndGet(router.getMetrics())); } } current.addConnection(connection); @@ -167,19 +176,6 @@ protected void remove(AsyncConnection connection) { } } - public Set> connections() { - connectionState.lock(); - try { - Set> connections = new HashSet<>(); - for (ConnectionGroup group : managedConnections.values()) { - connections.addAll(group.getConnections()); - } - return connections; - } finally { - connectionState.unlock(); - } - } - public Map> groups() { connectionState.lock(); try { diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/GroupChunkProcessor.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/GroupChunkProcessor.java index 8216424e5..f9586388e 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/GroupChunkProcessor.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/GroupChunkProcessor.java @@ -20,18 +20,18 @@ import java.util.Map; -public class GroupChunkProcessor extends ChunkProcessor { +public class GroupChunkProcessor implements ChunkProcessor { + protected Router fallbackRouter; - public GroupChunkProcessor(Router router) { - super(router); + public GroupChunkProcessor(Router fallbackRouter) { + this.fallbackRouter = fallbackRouter; } @Override public void process(ConnectionManager connectionManager, List chunks) { Map> groups = connectionManager.groups(); for (ConnectionGroup group : groups.values()) { - router.route(group.getConnections(), chunks); + group.route(chunks, fallbackRouter); } } - } diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouter.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouter.java new file mode 100644 index 000000000..58d19daf4 --- /dev/null +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouter.java @@ -0,0 +1,171 @@ +/* + * Copyright 2025 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.reactivex.mantis.network.push; + +import com.netflix.spectator.api.Tag; +import io.mantisrx.common.metrics.Counter; +import io.mantisrx.common.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import rx.functions.Func1; + +import java.util.*; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +public class ProactiveConsistentHashingRouter implements ProactiveRouter> { + private static final Logger logger = LoggerFactory.getLogger(ProactiveConsistentHashingRouter.class); + private final int connectionRepetitionOnRing; + + protected final Func1, byte[]> encoder; + protected final Counter numEventsRouted; + protected final Counter numEventsProcessed; + protected final Counter numConnectionUpdates; + protected final Metrics metrics; + private final HashFunction hashFunction; + private final NavigableMap>> ring = new TreeMap<>(); + private final ReadWriteLock ringLock = new ReentrantReadWriteLock(); + + public ProactiveConsistentHashingRouter( + String name, + Func1, byte[]> dataEncoder, + HashFunction hashFunction) { + this(name, dataEncoder, hashFunction, 1000); + } + + public ProactiveConsistentHashingRouter( + String name, + Func1, byte[]> dataEncoder, + HashFunction hashFunction, + int ringRepetitionPerConnection) { + this.connectionRepetitionOnRing = ringRepetitionPerConnection; + this.encoder = dataEncoder; + metrics = new Metrics.Builder() + .id("Router_" + name, Tag.of("router_type", "proactive_consistent_hashing")) + .addCounter("numEventsRouted") + .addCounter("numEventsProcessed") + .addCounter("numConnectionUpdates") + .build(); + numEventsRouted = metrics.getCounter("numEventsRouted"); + numEventsProcessed = metrics.getCounter("numEventsProcessed"); + numConnectionUpdates = metrics.getCounter("numConnectionUpdates"); + this.hashFunction = hashFunction; + } + + @Override + public void route(List> chunks) { + if (chunks == null || chunks.isEmpty()) { + return; + } + numEventsProcessed.increment(chunks.size()); + + // Read lock only for ring access + Map>, List> writes; + ringLock.readLock().lock(); + try { + if (ring.isEmpty()) { + return; + } + + int numConnections = ring.size() / connectionRepetitionOnRing; + int bufferCapacity = (chunks.size() / numConnections) + 1; // assume even distribution + writes = new HashMap<>(numConnections); + + // process chunks (ring access inside lookupConnection) + for (KeyValuePair kvp : chunks) { + long hash = kvp.getKeyBytesHashed(); + // lookup slot + Map.Entry>> connectionEntry = ring.ceilingEntry(hash); + AsyncConnection> connection = (connectionEntry == null ? ring.firstEntry() : connectionEntry).getValue(); + // add to writes + Func1, Boolean> predicate = connection.getPredicate(); + if (predicate == null || predicate.call(kvp)) { + List buffer = writes.computeIfAbsent(connection, k -> new ArrayList<>(bufferCapacity)); + buffer.add(encoder.call(kvp)); + } + } + } finally { + ringLock.readLock().unlock(); + } + + // process writes (outside lock - no ring access) + if (!writes.isEmpty()) { + for (Map.Entry>, List> entry : writes.entrySet()) { + AsyncConnection> connection = entry.getKey(); + List toWrite = entry.getValue(); + connection.write(toWrite); + numEventsRouted.increment(toWrite.size()); + } + } + } + + @Override + public void addConnection(AsyncConnection> connection) { + String connectionId = connection.getSlotId(); + if (connectionId == null) { + throw new IllegalStateException("Connection must specify an id for consistent hashing"); + } + + List hashCollisions = new ArrayList<>(); + ringLock.writeLock().lock(); + try { + for (int i = 0; i < connectionRepetitionOnRing; i++) { + // hash node on ring + byte[] connectionBytes = (connectionId + "-" + i).getBytes(); + long hash = hashFunction.computeHash(connectionBytes); + if (ring.containsKey(hash)) { + hashCollisions.add(connectionId + "-" + i); + } + ring.put(hash, connection); + } + } finally { + ringLock.writeLock().unlock(); + } + numConnectionUpdates.increment(); + + // Log outside lock + if (!hashCollisions.isEmpty()) { + logger.error("Hash collisions detected when adding connection {}: {}", connectionId, hashCollisions); + } + } + + @Override + public void removeConnection(AsyncConnection> connection) { + String connectionId = connection.getSlotId(); + if (connectionId == null) { + throw new IllegalStateException("Connection must specify an id for consistent hashing"); + } + + ringLock.writeLock().lock(); + try { + for (int i = 0; i < connectionRepetitionOnRing; i++) { + // hash node on ring + byte[] connectionBytes = (connectionId + "-" + i).getBytes(); + long hash = hashFunction.computeHash(connectionBytes); + ring.remove(hash); + } + } finally { + ringLock.writeLock().unlock(); + } + numConnectionUpdates.increment(); + } + + @Override + public Metrics getMetrics() { + return metrics; + } +} diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouter.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouter.java new file mode 100644 index 000000000..377bd6aa0 --- /dev/null +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouter.java @@ -0,0 +1,93 @@ +/* + * Copyright 2025 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.reactivex.mantis.network.push; + +import com.netflix.spectator.api.Tag; +import io.mantisrx.common.metrics.Counter; +import io.mantisrx.common.metrics.Metrics; +import rx.functions.Func1; + +import java.util.*; + +public class ProactiveRoundRobinRouter implements ProactiveRouter { + private final List> connections = new ArrayList<>(); + private int currentIndex = 0; + + protected Func1 encoder; + protected final Counter numEventsRouted; + protected final Counter numEventsProcessed; + protected final Counter numConnectionUpdates; + protected final Metrics metrics; + + public ProactiveRoundRobinRouter(String name, Func1 encoder) { + this.encoder = encoder; + metrics = new Metrics.Builder() + .id("Router_" + name, Tag.of("router_type", "proactive_round_robin")) + .addCounter("numEventsRouted") + .addCounter("numEventsProcessed") + .addCounter("numConnectionUpdates") + .build(); + numEventsRouted = metrics.getCounter("numEventsRouted"); + numEventsProcessed = metrics.getCounter("numEventsProcessed"); + numConnectionUpdates = metrics.getCounter("numConnectionUpdates"); + } + + @Override + public synchronized void addConnection(AsyncConnection connection) { + // We do not need to shuffle because we are constantly looping through + numConnectionUpdates.increment(); + connections.add(connection); + } + + @Override + public synchronized void removeConnection(AsyncConnection connection) { + numConnectionUpdates.increment(); + connections.remove(connection); + } + + @Override + public synchronized void route(List chunks) { + if (connections.isEmpty() || chunks == null || chunks.isEmpty()) { + return; + } + numEventsProcessed.increment(chunks.size()); + Map, List> writes = new HashMap<>(); + int arrayListSize = chunks.size() / connections.size() + 1; // assume even distribution + // process chunks + for (T chunk : chunks) { + currentIndex = currentIndex % connections.size(); + AsyncConnection connection = connections.get(currentIndex); + Func1 predicate = connection.getPredicate(); + if (predicate == null || predicate.call(chunk)) { + List buffer = writes.computeIfAbsent(connection, k -> new ArrayList<>(arrayListSize)); + buffer.add(encoder.call(chunk)); + currentIndex++; + } + } + for (Map.Entry, List> entry : writes.entrySet()) { + AsyncConnection connection = entry.getKey(); + List toWrite = entry.getValue(); + connection.write(toWrite); + numEventsRouted.increment(toWrite.size()); + } + } + + @Override + public Metrics getMetrics() { + return metrics; + } +} diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRouter.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRouter.java new file mode 100644 index 000000000..18f72d92c --- /dev/null +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ProactiveRouter.java @@ -0,0 +1,15 @@ +package io.reactivex.mantis.network.push; + +import io.mantisrx.common.metrics.Metrics; + +import java.util.List; + +public interface ProactiveRouter { + void route(List chunks); + + void addConnection(AsyncConnection connection); + + void removeConnection(AsyncConnection connection); + + Metrics getMetrics(); +} diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/PushServer.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/PushServer.java index c9ba2326f..a82a0d363 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/PushServer.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/PushServer.java @@ -108,7 +108,7 @@ public void call() { final MetricGroupId metricsGroup = new MetricGroupId("PushServer", idTag); // manager will auto add metrics for connection groups connectionManager = new ConnectionManager(metricsRegistry, doOnFirstConnection, - doOnZeroConnections); + doOnZeroConnections, config.getRouterFactory()); int numQueueProcessingThreads = config.getNumQueueConsumers(); @@ -154,8 +154,7 @@ public void call() { processedWrites = serverMetrics.getCounter("numProcessedWrites"); registerMetrics(metricsRegistry, serverMetrics, consumerThreads.getMetrics(), - outboundBuffer.getMetrics(), trigger.getMetrics(), - config.getChunkProcessor().router.getMetrics()); + outboundBuffer.getMetrics(), trigger.getMetrics()); port = config.getPort(); writeRetryCount = config.getWriteRetryCount(); @@ -165,14 +164,12 @@ public void call() { private void registerMetrics(MetricsRegistry registry, Metrics serverMetrics, Metrics consumerPoolMetrics, Metrics queueMetrics, - Metrics pushTriggerMetrics, - Metrics routerMetrics) { + Metrics pushTriggerMetrics) { registry.registerAndGet(serverMetrics); registry.registerAndGet(consumerPoolMetrics); registry.registerAndGet(queueMetrics); registry.registerAndGet(pushTriggerMetrics); - registry.registerAndGet(routerMetrics); } protected Observable manageConnection(final DefaultChannelWriter writer, String host, int port, diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/RouterFactory.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/RouterFactory.java index 6fc97c9e8..8166ff349 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/RouterFactory.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/RouterFactory.java @@ -2,6 +2,37 @@ import rx.functions.Func1; +import java.nio.ByteBuffer; + public interface RouterFactory { - public Router scalarStageToStageRouter(String name, final Func1 toBytes); + Router scalarStageToStageRouter(String name, final Func1 toBytes); + + default ProactiveRouter scalarStageToStageProactiveRouter(String name, final Func1 toBytes) { + return new ProactiveRoundRobinRouter<>(name, toBytes); + } + + default Router> keyedRouter(String name, Func1 keyEncoder, Func1 valueEncoder) { + return new ConsistentHashingRouter<>(name, RouterFactory.consistentHashingEncoder(valueEncoder), HashFunctions.xxh3()); + } + + default ProactiveRouter> keyedProactiveRouter(String name, Func1 keyEncoder, Func1 valueEncoder) { + return new ProactiveConsistentHashingRouter<>(name, RouterFactory.consistentHashingEncoder(valueEncoder), HashFunctions.xxh3()); + } + + public static Func1, byte[]> consistentHashingEncoder(final Func1 valueEncoder) { + return kvp -> { + byte[] keyBytes = kvp.getKeyBytes(); + byte[] valueBytes = valueEncoder.call(kvp.getValue()); + return + // length + opcode + notification type + key length + ByteBuffer.allocate(2 * Integer.BYTES + 2 * Byte.BYTES + keyBytes.length + valueBytes.length) + .putInt(2 * Byte.BYTES + Integer.BYTES + keyBytes.length + valueBytes.length) // length + .put((byte) 1) // opcode + .put((byte) 1) // notification type + .putInt(keyBytes.length) // key length + .put(keyBytes) // key bytes + .put(valueBytes) // value bytes + .array(); + }; + } } diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/Routers.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/Routers.java index 7e8592fba..4238a627a 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/Routers.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/Routers.java @@ -34,23 +34,7 @@ public Routers() {} public static Router> consistentHashingLegacyTcpProtocol(String name, final Func1 keyEncoder, final Func1 valueEncoder) { - return new ConsistentHashingRouter(name, new Func1, byte[]>() { - @Override - public byte[] call(KeyValuePair kvp) { - byte[] keyBytes = kvp.getKeyBytes(); - byte[] valueBytes = valueEncoder.call(kvp.getValue()); - return - // length + opcode + notification type + key length - ByteBuffer.allocate(4 + 1 + 1 + 4 + keyBytes.length + valueBytes.length) - .putInt(1 + 1 + 4 + keyBytes.length + valueBytes.length) // length - .put((byte) 1) // opcode - .put((byte) 1) // notification type - .putInt(keyBytes.length) // key length - .put(keyBytes) // key bytes - .put(valueBytes) // value bytes - .array(); - } - }, HashFunctions.xxh3()); + return new ConsistentHashingRouter(name, RouterFactory.consistentHashingEncoder(valueEncoder), HashFunctions.xxh3()); } private static byte[] dataPayload(byte[] data) { diff --git a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ServerConfig.java b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ServerConfig.java index 1c3088dda..d234614d5 100644 --- a/mantis-network/src/main/java/io/reactivex/mantis/network/push/ServerConfig.java +++ b/mantis-network/src/main/java/io/reactivex/mantis/network/push/ServerConfig.java @@ -19,6 +19,8 @@ import io.mantisrx.common.metrics.MetricsRegistry; import java.util.List; import java.util.Map; +import java.util.Optional; + import rx.functions.Func1; @@ -36,6 +38,7 @@ public class ServerConfig { private MetricsRegistry metricsRegistry; // registry used to store metrics private Func1>, Func1> predicate; private boolean useSpscQueue = false; + private final Func1>> routerFactory; public ServerConfig(Builder builder) { this.name = builder.name; @@ -50,6 +53,7 @@ public ServerConfig(Builder builder) { this.predicate = builder.predicate; this.useSpscQueue = builder.useSpscQueue; this.maxNotWritableTimeSec = builder.maxNotWritableTimeSec; + this.routerFactory = builder.routerFactory; } public Func1>, Func1> getPredicate() { @@ -100,6 +104,10 @@ public boolean useSpscQueue() { return useSpscQueue; } + public Func1>> getRouterFactory() { + return this.routerFactory; + } + public static class Builder { private String name; @@ -114,6 +122,7 @@ public static class Builder { private MetricsRegistry metricsRegistry; // registry used to store metrics private Func1>, Func1> predicate; private boolean useSpscQueue = false; + private Func1>> routerFactory = (String groupId) -> Optional.empty(); public Builder predicate(Func1>, Func1> predicate) { this.predicate = predicate; @@ -170,13 +179,13 @@ public Builder groupRouter(Router router) { return this; } - public Builder router(Router router) { - this.chunkProcessor = new ChunkProcessor<>(router); + public Builder metricsRegistry(MetricsRegistry metricsRegistry) { + this.metricsRegistry = metricsRegistry; return this; } - public Builder metricsRegistry(MetricsRegistry metricsRegistry) { - this.metricsRegistry = metricsRegistry; + public Builder proactiveRouterFactory(Func1>> routerFactory) { + this.routerFactory = routerFactory; return this; } diff --git a/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouterTest.java b/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouterTest.java new file mode 100644 index 000000000..01e0216b4 --- /dev/null +++ b/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveConsistentHashingRouterTest.java @@ -0,0 +1,204 @@ +/* + * Copyright 2025 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.reactivex.mantis.network.push; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import rx.subjects.PublishSubject; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ProactiveConsistentHashingRouterTest { + + private ProactiveConsistentHashingRouter router; + private HashFunction hashFunction; + + @BeforeEach + public void setup() { + hashFunction = HashFunctions.xxh3(); + router = new ProactiveConsistentHashingRouter<>("test-router", + kvp -> kvp.getValue().getBytes(), hashFunction); + } + + @Test + public void testRouteWithNoConnections() { + List> data = createTestData("key1", "value1"); + + // Should not throw exception + assertDoesNotThrow(() -> router.route(data)); + } + + @Test + public void testRouteWithNullData() { + TestAsyncConnection connection = createConnection("slot-1"); + router.addConnection(connection); + + // Should not throw exception + assertDoesNotThrow(() -> router.route(null)); + assertEquals(0, connection.getWrittenData().size()); + } + + @Test + public void testRouteWithEmptyData() { + TestAsyncConnection connection = createConnection("slot-1"); + router.addConnection(connection); + + router.route(Collections.emptyList()); + + assertEquals(0, connection.getWrittenData().size()); + } + + @Test + public void testConsistentHashing() { + // Add multiple connections + TestAsyncConnection connection1 = createConnection("slot-1"); + TestAsyncConnection connection2 = createConnection("slot-2"); + TestAsyncConnection connection3 = createConnection("slot-3"); + + router.addConnection(connection1); + router.addConnection(connection2); + router.addConnection(connection3); + + // Route the same key multiple times - should always go to same connection + String key = "consistent-key"; + for (int i = 0; i < 5; i++) { + List> data = createTestData(key, "value" + i); + router.route(data); + } + assertEquals(5, connection1.getWrittenData().size()); + assertEquals(0, connection2.getWrittenData().size()); + assertEquals(0, connection3.getWrittenData().size()); + } + + @Test + public void testDistributionAcrossMultipleConnections() { + // Add multiple connections + TestAsyncConnection connection1 = createConnection("slot-1"); + TestAsyncConnection connection2 = createConnection("slot-2"); + TestAsyncConnection connection3 = createConnection("slot-3"); + + router.addConnection(connection1); + router.addConnection(connection2); + router.addConnection(connection3); + + // Route many different keys + List> data = new ArrayList<>(); + int routed = 30000; + for (int i = 0; i < routed; i++) { + String key = "key-" + i; + long hash = hashFunction.computeHash(key.getBytes()); + data.add(new KeyValuePair<>(hash, key.getBytes(), "value-" + i)); + } + router.route(data); + + int actualRouted = connection1.getWrittenData().size() + + connection2.getWrittenData().size() + + connection3.getWrittenData().size(); + + assertEquals(routed, actualRouted); + // roughly even distribution, but allow 2% variance + assertEquals(10000, connection1.getWrittenData().size(), routed / 50.0); + assertEquals(10000, connection2.getWrittenData().size(), routed / 50.0); + assertEquals(10000, connection3.getWrittenData().size(), routed / 50.0); + } + + @Test + public void testRouteWithPredicate() { + // Create connection with predicate that filters out certain values + TestAsyncConnection connection = createConnection("slot-1", + kvp -> kvp.getValue().startsWith("accept")); + + router.addConnection(connection); + + // Route data - some should be filtered out + List> data = new ArrayList<>(); + data.add(createKeyValuePair("key1", "accept-value1")); + data.add(createKeyValuePair("key2", "reject-value2")); + data.add(createKeyValuePair("key3", "accept-value3")); + + router.route(data); + + // Only 2 values should have been routed + assertEquals(2, connection.getWrittenData().size()); + } + + @Test + public void testAddConnectionWithNullSlotId() { + PublishSubject> subject = PublishSubject.create(); + AsyncConnection> connection = + new AsyncConnection<>("host", 1234, "id1", null, "group1", subject, null); + + assertThrows(IllegalStateException.class, () -> router.addConnection(connection)); + } + + @Test + public void testRemoveConnectionWithNullSlotId() { + PublishSubject> subject = PublishSubject.create(); + AsyncConnection> connection = + new AsyncConnection<>("host", 1234, "id1", null, "group1", subject, null); + + assertThrows(IllegalStateException.class, () -> router.removeConnection(connection)); + } + + // Helper methods + + private TestAsyncConnection createConnection(String slotId) { + return createConnection(slotId, null); + } + + private TestAsyncConnection createConnection(String slotId, + rx.functions.Func1, Boolean> predicate) { + return new TestAsyncConnection(slotId, predicate); + } + + private List> createTestData(String key, String value) { + List> data = new ArrayList<>(); + data.add(createKeyValuePair(key, value)); + return data; + } + + private KeyValuePair createKeyValuePair(String key, String value) { + long hash = hashFunction.computeHash(key.getBytes()); + return new KeyValuePair<>(hash, key.getBytes(), value); + } + + // Test helper class + private static class TestAsyncConnection extends AsyncConnection> { + private final List writtenData = Collections.synchronizedList(new ArrayList<>()); + private int writeCalls = 0; + + public TestAsyncConnection(String slotId, + rx.functions.Func1, Boolean> predicate) { + super("test-host", 1234, "id-" + slotId, slotId, "test-group", + PublishSubject.create(), predicate); + } + + @Override + public synchronized void write(List data) { + writeCalls++; + writtenData.addAll(data); + } + + public synchronized List getWrittenData() { + return new ArrayList<>(writtenData); + } + } +} diff --git a/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouterTest.java b/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouterTest.java new file mode 100644 index 000000000..ba78ba6d6 --- /dev/null +++ b/mantis-network/src/test/java/io/reactivex/mantis/network/push/ProactiveRoundRobinRouterTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2025 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.reactivex.mantis.network.push; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import rx.functions.Func1; +import rx.subjects.PublishSubject; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +public class ProactiveRoundRobinRouterTest { + + private ProactiveRoundRobinRouter router; + + @BeforeEach + public void setup() { + router = new ProactiveRoundRobinRouter<>("test-router", String::getBytes); + } + + @Test + public void testRouteWithNoConnections() { + List data = Arrays.asList("test-value"); + + assertDoesNotThrow(() -> router.route(data)); + } + + @Test + public void testRouteWithNullData() { + TestAsyncConnection connection = createConnection("slot-1"); + router.addConnection(connection); + + // Should not throw exception + assertDoesNotThrow(() -> router.route(null)); + assertEquals(0, connection.getWrittenData().size()); + } + + @Test + public void testRouteWithEmptyData() { + TestAsyncConnection connection = createConnection("slot-1"); + router.addConnection(connection); + + router.route(Collections.emptyList()); + + assertEquals(0, connection.getWrittenData().size()); + } + + @Test + public void testRoundRobinDistribution() { + // Add 3 connections + TestAsyncConnection connection1 = createConnection("slot-1"); + TestAsyncConnection connection2 = createConnection("slot-2"); + TestAsyncConnection connection3 = createConnection("slot-3"); + + router.addConnection(connection1); + router.addConnection(connection2); + router.addConnection(connection3); + + // Route 9 items - should be evenly distributed + List data = new ArrayList<>(); + for (int i = 0; i < 9; i++) { + data.add("value-" + i); + } + router.route(data); + + // Each connection should receive 3 items + assertEquals(3, connection1.getWrittenData().size()); + assertEquals(3, connection2.getWrittenData().size()); + assertEquals(3, connection3.getWrittenData().size()); + } + + @Test + public void testPredicateAcrossMultipleConnections() { + // Create connections with different predicates + Func1 predicate = data -> data.contains("even"); + TestAsyncConnection connection1 = createConnection("slot-1", predicate); + TestAsyncConnection connection2 = createConnection("slot-2", predicate); + + router.addConnection(connection1); + router.addConnection(connection2); + + List data = Arrays.asList("value-even-0", "value-odd-1", "value-even-2", "value-odd-3"); + router.route(data); + + assertEquals(1, connection1.getWrittenData().size()); + assertEquals(1, connection2.getWrittenData().size()); + } + + @Test + public void testConnectionsAddedAfterRouting() { + TestAsyncConnection connection1 = createConnection("slot-1"); + router.addConnection(connection1); + + // Route some data + router.route(Arrays.asList("value1", "value2")); + assertEquals(2, connection1.getWrittenData().size()); + + // Add another connection + TestAsyncConnection connection2 = createConnection("slot-2"); + router.addConnection(connection2); + + // Route more data - should now distribute across both + router.route(Arrays.asList("value3", "value4", "value5", "value6")); + + assertEquals(2, connection2.getWrittenData().size()); + } + + @Test + public void testRemoveConnectionDuringRouting() { + TestAsyncConnection connection1 = createConnection("slot-1"); + TestAsyncConnection connection2 = createConnection("slot-2"); + TestAsyncConnection connection3 = createConnection("slot-3"); + + router.addConnection(connection1); + router.addConnection(connection2); + router.addConnection(connection3); + + // Route some data + router.route(Arrays.asList("value1", "value2", "value3")); + + // Remove middle connection + router.removeConnection(connection2); + + // Route more data + router.route(Arrays.asList("value4", "value5", "value6", "value7")); + + // Connection2 should not have received any data after removal + assertEquals(3, connection1.getWrittenData().size()); + assertEquals(1, connection2.getWrittenData().size()); // Only from first route + assertEquals(3, connection3.getWrittenData().size()); + } + + // Helper methods + + private TestAsyncConnection createConnection(String slotId) { + return createConnection(slotId, null); + } + + private TestAsyncConnection createConnection(String slotId, + rx.functions.Func1 predicate) { + return new TestAsyncConnection(slotId, predicate); + } + + // Test helper class + private static class TestAsyncConnection extends AsyncConnection { + private final List writtenData = new ArrayList<>(); + private int writeCalls = 0; + + public TestAsyncConnection(String slotId, rx.functions.Func1 predicate) { + super("test-host", 1234, "id-" + slotId, slotId, "test-group", + PublishSubject.create(), predicate); + } + + @Override + public void write(List data) { + writeCalls++; + writtenData.addAll(data); + } + + public List getWrittenData() { + return writtenData; + } + + public int getWriteCalls() { + return writeCalls; + } + } +} diff --git a/mantis-network/src/test/java/io/reactivex/mantis/network/push/TimedChunkerTest.java b/mantis-network/src/test/java/io/reactivex/mantis/network/push/TimedChunkerTest.java index 3654f156d..3757a287f 100644 --- a/mantis-network/src/test/java/io/reactivex/mantis/network/push/TimedChunkerTest.java +++ b/mantis-network/src/test/java/io/reactivex/mantis/network/push/TimedChunkerTest.java @@ -156,13 +156,12 @@ public void testLongProcessing() throws Exception { assertEquals(expected, processor.getProcessed()); } - public static class TestProcessor extends ChunkProcessor { + public static class TestProcessor implements ChunkProcessor { private ScheduledExecutorService scheduledService = Executors.newSingleThreadScheduledExecutor(); private List processed = new ArrayList(); private long processingTimeMs = 0; public TestProcessor(long processingTimeMs) { - super(null); this.processingTimeMs = processingTimeMs; } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/GroupToScalar.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/GroupToScalar.java index cfb4cf35a..2ffc63ae8 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/GroupToScalar.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/GroupToScalar.java @@ -56,7 +56,7 @@ public class GroupToScalar extends StageConfig { GroupToScalar(GroupToScalarComputation computation, Config config, Codec inputKeyCodec, Codec inputCodec) { - super(config.description, inputKeyCodec, inputCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency); + super(config.description, inputKeyCodec, inputCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency, config.useProactiveRouter); this.computation = computation; this.keyExpireTimeSeconds = config.keyExpireTimeSeconds; } @@ -80,6 +80,7 @@ public static class Config { private INPUT_STRATEGY inputStrategy = INPUT_STRATEGY.SERIAL; private int concurrency = DEFAULT_STAGE_CONCURRENCY; private List> parameters = Collections.emptyList(); + private boolean useProactiveRouter = false; /** * @param codec is netty reactivex Codec @@ -126,6 +127,11 @@ public Config concurrentInput(final int concurrency) { return this; } + public Config shouldUseProactiveRouter(boolean useProactiveRouter) { + this.useProactiveRouter = useProactiveRouter; + return this; + } + public Codec getCodec() { return codec; } @@ -149,6 +155,9 @@ public Config withParameters(List> params) { return this; } + public boolean getUseProactiveRouter() { + return useProactiveRouter; + } } } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyToKey.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyToKey.java index 5760934cd..9bd614995 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyToKey.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyToKey.java @@ -79,6 +79,7 @@ public static class Config { // do not allow config to override private final INPUT_STRATEGY inputStrategy = INPUT_STRATEGY.SERIAL; private List> parameters = Collections.emptyList(); + private boolean useProactiveRouter = false; /** * @param codec is a netty reactivex codec @@ -136,6 +137,16 @@ public Config withParameters(List> params) this.parameters = params; return this; } + + /** + * Configure this stage to use proactive routers for better connection management performance. + * + * @return this config for method chaining + */ + public Config withProactiveRouter() { + this.useProactiveRouter = true; + return this; + } } } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyValueStageConfig.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyValueStageConfig.java index 4e0cd6907..aa0a447bc 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyValueStageConfig.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/KeyValueStageConfig.java @@ -33,11 +33,15 @@ public abstract class KeyValueStageConfig extends StageConfig { private final Codec keyCodec; public KeyValueStageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputKeyCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params) { - this(description, inputKeyCodec, inputCodec, outputKeyCodec, outputCodec, inputStrategy, params, DEFAULT_STAGE_CONCURRENCY); + this(description, inputKeyCodec, inputCodec, outputKeyCodec, outputCodec, inputStrategy, params, false); } - public KeyValueStageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputKeyCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, int concurrency) { - super(description, inputKeyCodec, inputCodec, outputCodec, inputStrategy, params, concurrency); + public KeyValueStageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputKeyCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, boolean useProactiveRouter) { + this(description, inputKeyCodec, inputCodec, outputKeyCodec, outputCodec, inputStrategy, params, DEFAULT_STAGE_CONCURRENCY, useProactiveRouter); + } + + public KeyValueStageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputKeyCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, int concurrency, boolean useProactiveRouter) { + super(description, inputKeyCodec, inputCodec, outputCodec, inputStrategy, params, concurrency, useProactiveRouter); this.keyCodec = outputKeyCodec; } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToGroup.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToGroup.java index 7fb2a3206..9368c73d8 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToGroup.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToGroup.java @@ -54,7 +54,7 @@ public class ScalarToGroup extends KeyValueStageConfig { public ScalarToGroup(ToGroupComputation computation, Config config, Codec inputCodec) { - super(config.description, null, inputCodec, config.keyCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency); + super(config.description, null, inputCodec, config.keyCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency, config.useProactiveRouter); this.computation = computation; this.keyExpireTimeSeconds = config.keyExpireTimeSeconds; @@ -79,6 +79,7 @@ public static class Config { private int concurrency = DEFAULT_STAGE_CONCURRENCY; private long keyExpireTimeSeconds = Long.MAX_VALUE; // never expire by default private List> parameters = Collections.emptyList(); + private boolean useProactiveRouter = false; /** * @param codec is Codec of netty reactivex @@ -129,6 +130,11 @@ public Config description(String description) { return this; } + public Config useProactiveRouter(boolean useProactiveRouter) { + this.useProactiveRouter = useProactiveRouter; + return this; + } + public Codec getCodec() { return codec; } @@ -155,5 +161,9 @@ public Config withParameters(List> params) { this.parameters = params; return this; } + + public boolean isUseProactiveRouter() { + return useProactiveRouter; + } } } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToScalar.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToScalar.java index 7ba4309e2..5cc374a80 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToScalar.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/ScalarToScalar.java @@ -34,15 +34,13 @@ public class ScalarToScalar extends StageConfig { */ ScalarToScalar(ScalarComputation computation, Config config, final io.reactivex.netty.codec.Codec inputCodec) { - super(config.description, NettyCodec.fromNetty(inputCodec), config.codec, config.inputStrategy, config.parameters, config.concurrency); - this.computation = computation; - this.inputStrategy = config.inputStrategy; + this(computation, config, NettyCodec.fromNetty(inputCodec)); } public ScalarToScalar(ScalarComputation computation, Config config, Codec inputCodec) { - super(config.description, inputCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency); + super(config.description, inputCodec, config.codec, config.inputStrategy, config.parameters, config.concurrency, config.useProactiveRouter); this.computation = computation; this.inputStrategy = config.inputStrategy; this.parameters = config.parameters; @@ -68,6 +66,7 @@ public static class Config { // default input type is serial for 'collecting' use case private INPUT_STRATEGY inputStrategy = INPUT_STRATEGY.SERIAL; private volatile int concurrency = StageConfig.DEFAULT_STAGE_CONCURRENCY; + private boolean useProactiveRouter = false; private List> parameters = Collections.emptyList(); @@ -118,6 +117,16 @@ public Config withParameters(List> params) { return this; } + /** + * Configure this stage to use proactive routers for better connection management performance. + * + * @return this config for method chaining + */ + public Config withProactiveRouter() { + this.useProactiveRouter = true; + return this; + } + public String getDescription() { return description; } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/StageConfig.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/StageConfig.java index a920c1987..5eeb796dc 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/StageConfig.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/StageConfig.java @@ -40,6 +40,9 @@ public abstract class StageConfig { // number of inner observables processed private int concurrency = DEFAULT_STAGE_CONCURRENCY; + // determines whether to use proactive routers for better connection management performance + private boolean useProactiveRouter = false; + public StageConfig(String description, Codec inputCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy) { this(description, inputCodec, outputCodec, inputStrategy, Collections.emptyList(), DEFAULT_STAGE_CONCURRENCY); @@ -52,7 +55,7 @@ public StageConfig(String description, Codec inputCodec, public StageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params) { - this(description, inputKeyCodec, inputCodec, outputCodec, inputStrategy, params, DEFAULT_STAGE_CONCURRENCY); + this(description, inputKeyCodec, inputCodec, outputCodec, inputStrategy, params, DEFAULT_STAGE_CONCURRENCY, false); } public StageConfig(String description, Codec inputCodec, @@ -63,12 +66,18 @@ public StageConfig(String description, Codec inputCodec, public StageConfig(String description, Codec inputCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, int concurrency) { - this(description, null, inputCodec, outputCodec, inputStrategy, params, concurrency); + this(description, inputCodec, outputCodec, inputStrategy, params, concurrency, false); + } + + public StageConfig(String description, Codec inputCodec, + Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, + int concurrency, boolean useProactiveRouter) { + this(description, null, inputCodec, outputCodec, inputStrategy, params, concurrency, useProactiveRouter); } public StageConfig(String description, Codec inputKeyCodec, Codec inputCodec, Codec outputCodec, INPUT_STRATEGY inputStrategy, List> params, - int concurrency) { + int concurrency, boolean useProactiveRouter) { this.description = description; this.inputKeyCodec = inputKeyCodec; this.inputCodec = inputCodec; @@ -76,6 +85,7 @@ public StageConfig(String description, Codec inputKeyCodec, Codec inpu this.inputStrategy = inputStrategy; this.parameters = params; this.concurrency = concurrency; + this.useProactiveRouter = useProactiveRouter; } public String getDescription() { @@ -109,5 +119,12 @@ public int getConcurrency() { return concurrency; } + /** + * @return true if proactive routers should be used for this stage + */ + public boolean shouldUseProactiveRouter() { + return useProactiveRouter; + } + public enum INPUT_STRATEGY {NONE_SPECIFIED, SERIAL, CONCURRENT} } diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/executor/WorkerPublisherRemoteObservable.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/executor/WorkerPublisherRemoteObservable.java index 73092c926..86739dd8c 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/executor/WorkerPublisherRemoteObservable.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/executor/WorkerPublisherRemoteObservable.java @@ -39,6 +39,8 @@ import rx.Observable; import rx.functions.Func1; +import java.util.Optional; + /** * Execution of WorkerPublisher that publishes the stream to the next stage. * @@ -86,12 +88,17 @@ public void start(final StageConfig stage, Observable> toSer Func1 encoder = t1 -> stage.getOutputCodec().encode(t1); Router router = this.routerFactory.scalarStageToStageRouter(name, encoder); + Func1>> proactiveFactory = (String k) -> Optional.empty(); + if (stage.shouldUseProactiveRouter()) { + proactiveFactory = (String name) -> Optional.of(routerFactory.scalarStageToStageProactiveRouter(name, encoder)); + } ServerConfig config = new ServerConfig.Builder() .name(name) .port(serverPort) .metricsRegistry(MetricsRegistry.getInstance()) - .router(router) + .groupRouter(router) + .proactiveRouterFactory(proactiveFactory) .build(); final LegacyTcpPushServer modernServer = PushServers.infiniteStreamLegacyTcpNested(config, toServe); @@ -135,6 +142,12 @@ private LegacyTcpPushServer> startKeyValueStage(KeyValueS Func1 valueEncoder = t1 -> stage.getOutputCodec().encode(t1); Func1 keyEncoder = t1 -> stage.getOutputKeyCodec().encode(t1); + Router> router = this.routerFactory.keyedRouter(name, keyEncoder, valueEncoder); + Func1>>> proactiveFactory = (String k) -> Optional.empty(); + if (stage.shouldUseProactiveRouter()) { + proactiveFactory = (String name) -> Optional.of(routerFactory.keyedProactiveRouter(name, keyEncoder, valueEncoder)); + } + ServerConfig> config = new ServerConfig.Builder>() .name(name) .port(serverPort) @@ -144,7 +157,8 @@ private LegacyTcpPushServer> startKeyValueStage(KeyValueS .maxChunkTimeMSec(maxChunkTimeMSec()) .bufferCapacity(bufferCapacity()) .useSpscQueue(useSpsc()) - .router(Routers.consistentHashingLegacyTcpProtocol(jobName, keyEncoder, valueEncoder)) + .groupRouter(router) + .proactiveRouterFactory(proactiveFactory) .build(); if (stage instanceof ScalarToGroup || stage instanceof GroupToGroup) { @@ -161,7 +175,6 @@ private LegacyTcpPushServer> startKeyValueStage(KeyValueS private boolean useSpsc() { String stringValue = propService.getStringValue("mantis.w2w.spsc", "false"); return Boolean.parseBoolean(stringValue); - } private int bufferCapacity() { diff --git a/mantis-runtime/src/main/java/io/mantisrx/runtime/sink/ServerSentEventsSink.java b/mantis-runtime/src/main/java/io/mantisrx/runtime/sink/ServerSentEventsSink.java index 71c36504a..f71fb849d 100644 --- a/mantis-runtime/src/main/java/io/mantisrx/runtime/sink/ServerSentEventsSink.java +++ b/mantis-runtime/src/main/java/io/mantisrx/runtime/sink/ServerSentEventsSink.java @@ -25,14 +25,18 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelOption; import io.netty.channel.WriteBufferWaterMark; +import io.reactivex.mantis.network.push.ProactiveRouter; import io.reactivex.mantis.network.push.PushServerSse; import io.reactivex.mantis.network.push.PushServers; import io.reactivex.mantis.network.push.Routers; import io.reactivex.mantis.network.push.ServerConfig; import io.reactivex.mantis.network.push.Router; + import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; + import mantis.io.reactivex.netty.RxNetty; import mantis.io.reactivex.netty.pipeline.PipelineConfigurators; import mantis.io.reactivex.netty.protocol.http.server.HttpServer; @@ -58,6 +62,7 @@ public class ServerSentEventsSink implements SelfDocumentingSink { private int port = -1; private final MantisPropertiesLoader propService; private final Router router; + private Func1>> proactiveRouterFactory = (String routerName) -> Optional.empty(); private PushServerSse pushServerSse; private HttpServer httpServer; @@ -90,6 +95,7 @@ public ServerSentEventsSink(Func1 encoder) { this.subscribeProcessor = builder.subscribeProcessor; this.propService = ServiceRegistry.INSTANCE.getPropertiesService(); this.router = builder.router; + this.proactiveRouterFactory = builder.proactiveRouterFactory; } @Override @@ -165,7 +171,8 @@ public void call(Context context, PortRequest portRequest, final Observable o .numQueueConsumers(numConsumerThreads()) .useSpscQueue(useSpsc()) .maxChunkTimeMSec(getBatchInterval()) - .maxNotWritableTimeSec(maxNotWritableTimeSec()); + .maxNotWritableTimeSec(maxNotWritableTimeSec()) + .proactiveRouterFactory(proactiveRouterFactory); if (predicate != null) { config.predicate(predicate.getPredicate()); } @@ -254,6 +261,7 @@ public static class Builder { private Predicate predicate; private Func2>, Context, Void> subscribeProcessor; private Router router; + private Func1>> proactiveRouterFactory = (String routerName) -> Optional.empty(); public Builder withEncoder(Func1 encoder) { this.encoder = encoder; @@ -291,6 +299,11 @@ public Builder withRouter(Router router) { return this; } + public Builder withProactiveRouterFactory(Func1>> proactiveRouterFactory) { + this.proactiveRouterFactory = proactiveRouterFactory; + return this; + } + public ServerSentEventsSink build() { return new ServerSentEventsSink<>(this); }