Skip to content
Merged
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,8 +1121,14 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
# DP attention. Use TCP + port to handle both single-node and multi-node.
if server_args.nnodes == 1 and server_args.dist_init_addr is None:
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
elif server_args.dist_init_addr.startswith("["): # ipv6 address

port, host = PortArgs.configure_ipv6(server_args)

dist_init_addr = (host, str(port))
else:
dist_init_addr = server_args.dist_init_addr.split(":")

assert (
len(dist_init_addr) == 2
), "please provide --dist-init-addr as host:port of head node"
Expand Down
32 changes: 32 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,38 @@ def is_valid_ipv6_address(address: str) -> bool:
return False


def configure_ipv6(dist_init_addr):
addr = server_args.dist_init_addr
end = addr.find("]")
if end == -1:
raise ValueError("invalid IPv6 address format: missing ']'")

host = addr[: end + 1]

# this only validates the address without brackets: we still need the below checks.
# if it's invalid, immediately raise an error so we know it's not formatting issues.
if not is_valid_ipv6_address(host[1:end]):
raise ValueError(f"invalid IPv6 address: {host}")

port_str = None
if len(addr) > end + 1:
if addr[end + 1] == ":":
port_str = addr[end + 2 :]
else:
raise ValueError("received IPv6 address format: expected ':' after ']'")

if not port_str:
raise ValueError(
"a port must be specified in IPv6 address (format: [ipv6]:port)"
)

try:
port = int(port_str)
except ValueError:
raise ValueError(f"invalid port in IPv6 address: '{port_str}'")
return port, host


def rank0_print(msg: str):
from sglang.srt.distributed import get_tensor_model_parallel_rank

Expand Down
237 changes: 236 additions & 1 deletion test/srt/test_server_args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import unittest
from unittest.mock import MagicMock, patch

from sglang.srt.server_args import prepare_server_args
from sglang.srt.server_args import PortArgs, ServerArgs, prepare_server_args


class TestPrepareServerArgs(unittest.TestCase):
Expand All @@ -21,5 +22,239 @@ def test_prepare_server_args(self):
)


class TestPortArgs(unittest.TestCase):
@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
def test_init_new_standard_case(self, mock_temp_file, mock_is_port_available):

mock_is_port_available.return_value = True
mock_temp_file.return_value.name = "temp_file"

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = False

port_args = PortArgs.init_new(server_args)

self.assertTrue(port_args.tokenizer_ipc_name.startswith("ipc://"))
self.assertTrue(port_args.scheduler_input_ipc_name.startswith("ipc://"))
self.assertTrue(port_args.detokenizer_ipc_name.startswith("ipc://"))
self.assertIsInstance(port_args.nccl_port, int)

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_single_node_dp_attention(self, mock_is_port_available):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 1
server_args.dist_init_addr = None

port_args = PortArgs.init_new(server_args)

self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://127.0.0.1:"))
self.assertTrue(
port_args.scheduler_input_ipc_name.startswith("tcp://127.0.0.1:")
)
self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://127.0.0.1:"))
self.assertIsInstance(port_args.nccl_port, int)

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_dp_rank(self, mock_is_port_available):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 1
server_args.dist_init_addr = "192.168.1.1:25000"

port_args = PortArgs.init_new(server_args, dp_rank=2)

self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25006"))

self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
self.assertIsInstance(port_args.nccl_port, int)

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_ipv4_address(self, mock_is_port_available):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "192.168.1.1:25000"

port_args = PortArgs.init_new(server_args)

self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
self.assertTrue(
port_args.scheduler_input_ipc_name.startswith("tcp://192.168.1.1:")
)
self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
self.assertIsInstance(port_args.nccl_port, int)

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_malformed_ipv4_address(self, mock_is_port_available):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "192.168.1.1"

with self.assertRaises(AssertionError) as context:
PortArgs.init_new(server_args)

self.assertIn(
"please provide --dist-init-addr as host:port", str(context.exception)
)

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_malformed_ipv4_address_invalid_port(
self, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "192.168.1.1:abc"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
def test_init_new_with_ipv6_address(
self, mock_is_valid_ipv6, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[2001:db8::1]:25000"

port_args = PortArgs.init_new(server_args)

self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://[2001:db8::1]:"))
self.assertTrue(
port_args.scheduler_input_ipc_name.startswith("tcp://[2001:db8::1]:")
)
self.assertTrue(
port_args.detokenizer_ipc_name.startswith("tcp://[2001:db8::1]:")
)
self.assertIsInstance(port_args.nccl_port, int)

@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False)
def test_init_new_with_invalid_ipv6_address(
self, mock_is_valid_ipv6, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[invalid-ipv6]:25000"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

self.assertIn("invalid IPv6 address", str(context.exception))

@patch("sglang.srt.server_args.is_port_available")
def test_init_new_with_malformed_ipv6_address_missing_bracket(
self, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[2001:db8::1:25000"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

self.assertIn("invalid IPv6 address format", str(context.exception))

@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
def test_init_new_with_malformed_ipv6_address_missing_port(
self, mock_is_valid_ipv6, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[2001:db8::1]"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

self.assertIn(
"a port must be specified in IPv6 address", str(context.exception)
)

@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
def test_init_new_with_malformed_ipv6_address_invalid_port(
self, mock_is_valid_ipv6, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[2001:db8::1]:abcde"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

self.assertIn("invalid port in IPv6 address", str(context.exception))

@patch("sglang.srt.server_args.is_port_available")
@patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
def test_init_new_with_malformed_ipv6_address_wrong_separator(
self, mock_is_valid_ipv6, mock_is_port_available
):

mock_is_port_available.return_value = True

server_args = MagicMock()
server_args.port = 30000
server_args.enable_dp_attention = True
server_args.nnodes = 2
server_args.dist_init_addr = "[2001:db8::1]#25000"

with self.assertRaises(ValueError) as context:
PortArgs.init_new(server_args)

self.assertIn("expected ':' after ']'", str(context.exception))


if __name__ == "__main__":
unittest.main()
Loading