Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.digitalpetri.modbus.tcp.ModbusTcpCodec;
import com.digitalpetri.netty.fsm.ChannelActions;
import com.digitalpetri.netty.fsm.ChannelFsm;
import com.digitalpetri.netty.fsm.ChannelFsm.TransitionListener;
import com.digitalpetri.netty.fsm.ChannelFsmConfig;
import com.digitalpetri.netty.fsm.ChannelFsmFactory;
import com.digitalpetri.netty.fsm.Event;
Expand All @@ -20,8 +21,10 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.slf4j.Logger;
Expand All @@ -37,6 +40,8 @@ public class NettyTcpClientTransport implements ModbusTcpClientTransport {

private final AtomicReference<Consumer<ModbusTcpFrame>> frameReceiver = new AtomicReference<>();

private final List<ConnectionListener> connectionListeners = new CopyOnWriteArrayList<>();

private final ChannelFsm channelFsm;
private final ExecutionQueue executionQueue;

Expand All @@ -55,12 +60,31 @@ public NettyTcpClientTransport(NettyClientTransportConfig config) {
.build()
);

executionQueue = new ExecutionQueue(config.executor());

channelFsm.addTransitionListener(
(from, to, via) ->
logger.debug("onStateTransition: {} -> {} via {}", from, to, via)
(from, to, via) -> {
logger.debug("onStateTransition: {} -> {} via {}", from, to, via);

maybeNotifyConnectionListeners(from, to);
}
);
}

executionQueue = new ExecutionQueue(config.executor());
private void maybeNotifyConnectionListeners(State from, State to) {
if (connectionListeners.isEmpty()) {
return;
}

if (from != State.Connected && to == State.Connected) {
executionQueue.submit(() ->
connectionListeners.forEach(ConnectionListener::onConnection)
);
} else if (from == State.Connected && to != State.Connected) {
executionQueue.submit(() ->
connectionListeners.forEach(ConnectionListener::onConnectionLost)
);
}
}

@Override
Expand Down Expand Up @@ -100,6 +124,36 @@ public boolean isConnected() {
return channelFsm.getState() == State.Connected;
}

/**
* Get the {@link ChannelFsm} used by this transport.
*
* <p>This should not generally be used by client code except perhaps to add a
* {@link TransitionListener} to receive more detailed callbacks about the connection status.
*
* @return the {@link ChannelFsm} used by this transport.
*/
public ChannelFsm getChannelFsm() {
return channelFsm;
}

/**
* Add a {@link ConnectionListener} to this transport.
*
* @param listener the listener to add.
*/
public void addConnectionListener(ConnectionListener listener) {
connectionListeners.add(listener);
}

/**
* Remove a {@link ConnectionListener} from this transport.
*
* @param listener the listener to remove.
*/
public void removeConnectionListener(ConnectionListener listener) {
connectionListeners.remove(listener);
}

private class ModbusTcpFrameHandler extends SimpleChannelInboundHandler<ModbusTcpFrame> {

@Override
Expand Down Expand Up @@ -180,4 +234,21 @@ public static NettyTcpClientTransport create(
return new NettyTcpClientTransport(config);
}

public interface ConnectionListener {

/**
* Callback invoked when the transport has connected.
*/
void onConnection();

/**
* Callback invoked when the transport has disconnected.
*
* <p>Note that implementations do not need to initiate a reconnect, as this is handled
* automatically by {@link NettyTcpClientTransport}.
*/
void onConnectionLost();

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.digitalpetri.modbus.test;

import static org.junit.jupiter.api.Assertions.assertTrue;

import com.digitalpetri.modbus.ModbusPduSerializer.DefaultRequestSerializer;
import com.digitalpetri.modbus.client.ModbusClient;
import com.digitalpetri.modbus.client.ModbusTcpClient;
Expand All @@ -11,10 +13,13 @@
import com.digitalpetri.modbus.server.ReadWriteModbusServices;
import com.digitalpetri.modbus.tcp.Netty;
import com.digitalpetri.modbus.tcp.client.NettyTcpClientTransport;
import com.digitalpetri.modbus.tcp.client.NettyTcpClientTransport.ConnectionListener;
import com.digitalpetri.modbus.tcp.client.NettyTimeoutScheduler;
import com.digitalpetri.modbus.tcp.server.NettyTcpServerTransport;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -24,6 +29,8 @@ public class ModbusTcpClientServerIT extends ClientServerIT {
ModbusTcpClient client;
ModbusTcpServer server;

NettyTcpClientTransport clientTransport;

@BeforeEach
void setup() throws Exception {
var processImage = new ProcessImage();
Expand Down Expand Up @@ -59,7 +66,7 @@ protected Optional<ProcessImage> getProcessImage(int unitId) {
}

final var port = serverPort;
var clientTransport = NettyTcpClientTransport.create(
clientTransport = NettyTcpClientTransport.create(
cfg -> {
cfg.hostname = "localhost";
cfg.port = port;
Expand Down Expand Up @@ -112,4 +119,28 @@ void sendRaw() throws Exception {
System.out.println("responsePduBytes: " + Hex.format(responsePduBytes));
}

@Test
void connectionListener() throws Exception {
var onConnection = new CountDownLatch(1);
var onConnectionLost = new CountDownLatch(1);

clientTransport.addConnectionListener(new ConnectionListener() {
@Override
public void onConnection() {
onConnection.countDown();
}

@Override
public void onConnectionLost() {
onConnectionLost.countDown();
}
});

client.disconnect();
assertTrue(onConnectionLost.await(1, TimeUnit.SECONDS));

client.connect();
assertTrue(onConnection.await(1, TimeUnit.SECONDS));
}

}