diff --git a/modbus-tcp/src/main/java/com/digitalpetri/modbus/tcp/client/NettyTcpClientTransport.java b/modbus-tcp/src/main/java/com/digitalpetri/modbus/tcp/client/NettyTcpClientTransport.java index 58e0c94..72fbacd 100644 --- a/modbus-tcp/src/main/java/com/digitalpetri/modbus/tcp/client/NettyTcpClientTransport.java +++ b/modbus-tcp/src/main/java/com/digitalpetri/modbus/tcp/client/NettyTcpClientTransport.java @@ -7,6 +7,7 @@ import com.digitalpetri.modbus.tcp.ModbusTcpCodec; import com.digitalpetri.netty.fsm.ChannelActions; import com.digitalpetri.netty.fsm.ChannelFsm; +import com.digitalpetri.netty.fsm.ChannelFsm.TransitionListener; import com.digitalpetri.netty.fsm.ChannelFsmConfig; import com.digitalpetri.netty.fsm.ChannelFsmFactory; import com.digitalpetri.netty.fsm.Event; @@ -20,8 +21,10 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import org.slf4j.Logger; @@ -37,6 +40,8 @@ public class NettyTcpClientTransport implements ModbusTcpClientTransport { private final AtomicReference> frameReceiver = new AtomicReference<>(); + private final List connectionListeners = new CopyOnWriteArrayList<>(); + private final ChannelFsm channelFsm; private final ExecutionQueue executionQueue; @@ -55,12 +60,31 @@ public NettyTcpClientTransport(NettyClientTransportConfig config) { .build() ); + executionQueue = new ExecutionQueue(config.executor()); + channelFsm.addTransitionListener( - (from, to, via) -> - logger.debug("onStateTransition: {} -> {} via {}", from, to, via) + (from, to, via) -> { + logger.debug("onStateTransition: {} -> {} via {}", from, to, via); + + maybeNotifyConnectionListeners(from, to); + } ); + } - executionQueue = new ExecutionQueue(config.executor()); + private void maybeNotifyConnectionListeners(State from, State to) { + if (connectionListeners.isEmpty()) { + return; + } + + if (from != State.Connected && to == State.Connected) { + executionQueue.submit(() -> + connectionListeners.forEach(ConnectionListener::onConnection) + ); + } else if (from == State.Connected && to != State.Connected) { + executionQueue.submit(() -> + connectionListeners.forEach(ConnectionListener::onConnectionLost) + ); + } } @Override @@ -100,6 +124,36 @@ public boolean isConnected() { return channelFsm.getState() == State.Connected; } + /** + * Get the {@link ChannelFsm} used by this transport. + * + *

This should not generally be used by client code except perhaps to add a + * {@link TransitionListener} to receive more detailed callbacks about the connection status. + * + * @return the {@link ChannelFsm} used by this transport. + */ + public ChannelFsm getChannelFsm() { + return channelFsm; + } + + /** + * Add a {@link ConnectionListener} to this transport. + * + * @param listener the listener to add. + */ + public void addConnectionListener(ConnectionListener listener) { + connectionListeners.add(listener); + } + + /** + * Remove a {@link ConnectionListener} from this transport. + * + * @param listener the listener to remove. + */ + public void removeConnectionListener(ConnectionListener listener) { + connectionListeners.remove(listener); + } + private class ModbusTcpFrameHandler extends SimpleChannelInboundHandler { @Override @@ -180,4 +234,21 @@ public static NettyTcpClientTransport create( return new NettyTcpClientTransport(config); } + public interface ConnectionListener { + + /** + * Callback invoked when the transport has connected. + */ + void onConnection(); + + /** + * Callback invoked when the transport has disconnected. + * + *

Note that implementations do not need to initiate a reconnect, as this is handled + * automatically by {@link NettyTcpClientTransport}. + */ + void onConnectionLost(); + + } + } diff --git a/modbus-tests/src/test/java/com/digitalpetri/modbus/test/ModbusTcpClientServerIT.java b/modbus-tests/src/test/java/com/digitalpetri/modbus/test/ModbusTcpClientServerIT.java index c2a2a4b..5cda666 100644 --- a/modbus-tests/src/test/java/com/digitalpetri/modbus/test/ModbusTcpClientServerIT.java +++ b/modbus-tests/src/test/java/com/digitalpetri/modbus/test/ModbusTcpClientServerIT.java @@ -1,5 +1,7 @@ package com.digitalpetri.modbus.test; +import static org.junit.jupiter.api.Assertions.assertTrue; + import com.digitalpetri.modbus.ModbusPduSerializer.DefaultRequestSerializer; import com.digitalpetri.modbus.client.ModbusClient; import com.digitalpetri.modbus.client.ModbusTcpClient; @@ -11,10 +13,13 @@ import com.digitalpetri.modbus.server.ReadWriteModbusServices; import com.digitalpetri.modbus.tcp.Netty; import com.digitalpetri.modbus.tcp.client.NettyTcpClientTransport; +import com.digitalpetri.modbus.tcp.client.NettyTcpClientTransport.ConnectionListener; import com.digitalpetri.modbus.tcp.client.NettyTimeoutScheduler; import com.digitalpetri.modbus.tcp.server.NettyTcpServerTransport; import java.nio.ByteBuffer; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -24,6 +29,8 @@ public class ModbusTcpClientServerIT extends ClientServerIT { ModbusTcpClient client; ModbusTcpServer server; + NettyTcpClientTransport clientTransport; + @BeforeEach void setup() throws Exception { var processImage = new ProcessImage(); @@ -59,7 +66,7 @@ protected Optional getProcessImage(int unitId) { } final var port = serverPort; - var clientTransport = NettyTcpClientTransport.create( + clientTransport = NettyTcpClientTransport.create( cfg -> { cfg.hostname = "localhost"; cfg.port = port; @@ -112,4 +119,30 @@ void sendRaw() throws Exception { System.out.println("responsePduBytes: " + Hex.format(responsePduBytes)); } + @Test + void connectionListener() throws Exception { + var onConnection = new CountDownLatch(1); + var onConnectionLost = new CountDownLatch(1); + + clientTransport.addConnectionListener(new ConnectionListener() { + @Override + public void onConnection() { + onConnection.countDown(); + } + + @Override + public void onConnectionLost() { + onConnectionLost.countDown(); + } + }); + + assertTrue(client.isConnected()); + + client.disconnect(); + assertTrue(onConnectionLost.await(1, TimeUnit.SECONDS)); + + client.connect(); + assertTrue(onConnection.await(1, TimeUnit.SECONDS)); + } + }