Skip to content

Commit dea21c3

Browse files
committed
chore: optimize trojan codebase
1 parent 0493a99 commit dea21c3

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

tunnels/trojan/server.go

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ type UDPPacket struct {
2525

2626
type Server struct {
2727
PasswordHashes []string
28-
underlying net.Interface
28+
Underlying net.Interface
2929

3030
initOnce sync.Once
3131
}
3232

3333
func (s *Server) Handle(conn *tcp.Conn) tcp.SerRet {
3434
s.initOnce.Do(func() {
35-
if s.underlying == nil {
36-
s.underlying = net.DefaultRouteTable
35+
if s.Underlying == nil {
36+
s.Underlying = net.DefaultRouteTable
3737
}
3838
})
3939

@@ -66,15 +66,17 @@ func (s *Server) Handle(conn *tcp.Conn) tcp.SerRet {
6666
switch metadata.Command {
6767
case 0x1: // Connect
6868
var target net.Conn
69+
var address string
6970
switch metadata.Address.AddressType {
7071
case IPv4, IPv6:
71-
target, err = s.underlying.Dial("tcp", gonet.JoinHostPort(metadata.Address.IP.String(), strconv.Itoa(metadata.Address.Port)))
72-
log.Verboseln("[TROJAN]", "Dialed TCP", metadata.Address.IP.String(), metadata.Address.Port)
72+
address = gonet.JoinHostPort(metadata.Address.IP.String(), strconv.Itoa(metadata.Address.Port))
7373
case DomainName:
74-
target, err = s.underlying.Dial("tcp", gonet.JoinHostPort(metadata.Address.DomainName, strconv.Itoa(metadata.Address.Port)))
75-
log.Verboseln("[TROJAN]", "Dialed TCP", metadata.Address.DomainName, metadata.Address.Port)
74+
address = gonet.JoinHostPort(metadata.Address.DomainName, strconv.Itoa(metadata.Address.Port))
7675
}
7776

77+
target, err = s.Underlying.Dial("tcp", address)
78+
log.Verboseln("[TROJAN]", "Dialed TCP", address)
79+
7880
if err != nil {
7981
log.Verboseln("[TROJAN]", "Dial failed", err)
8082
return tcp.Close
@@ -160,43 +162,30 @@ func (s *Server) Handle(conn *tcp.Conn) tcp.SerRet {
160162
}
161163

162164
// Determine target address
163-
var targetAddr *gonet.UDPAddr
164-
var addrKey string
165+
var address string
165166

166167
switch addr.AddressType {
167168
case IPv4, IPv6:
168-
targetAddr = &gonet.UDPAddr{
169-
IP: addr.IP,
170-
Port: addr.Port,
171-
}
172-
addrKey = targetAddr.String()
169+
address = gonet.JoinHostPort(addr.IP.String(), strconv.Itoa(addr.Port))
173170
case DomainName:
174-
// Resolve domain name
175-
resolvedAddr, err := gonet.ResolveUDPAddr("udp", gonet.JoinHostPort(addr.DomainName, strconv.Itoa(addr.Port)))
176-
if err != nil {
177-
log.Verboseln("[TROJAN]", "failed to resolve domain", addr.DomainName, err)
178-
continue
179-
}
180-
targetAddr = resolvedAddr
181-
addrKey = gonet.JoinHostPort(addr.DomainName, strconv.Itoa(addr.Port))
171+
address = gonet.JoinHostPort(addr.DomainName, strconv.Itoa(addr.Port))
182172
}
183173

184-
// Get or create UDP connection for this target
185-
target, exists := connTable[addrKey]
174+
target, exists := connTable[address]
186175
if !exists {
187-
_target, err := s.underlying.Dial("udp", gonet.JoinHostPort(targetAddr.IP.String(), strconv.Itoa(targetAddr.Port)))
176+
_target, err := s.Underlying.Dial("udp", address)
188177

189178
if err != nil {
190-
log.Verboseln("[TROJAN]", "DialUDP failed for", addrKey, err)
179+
log.Verboseln("[TROJAN]", "DialUDP failed for", address, err)
191180
continue
192181
}
193182

194183
target = _target.(*gonet.UDPConn)
195184

196-
connTable[addrKey] = target
185+
connTable[address] = target
197186
go downlink(target)
198187

199-
log.Verboseln("[TROJAN]", "Created new UDP connection to", addrKey)
188+
log.Verboseln("[TROJAN]", "new UDP connection to", address)
200189
}
201190

202191
n, err := io.ReadFull(r, buf[:length])
@@ -210,7 +199,7 @@ func (s *Server) Handle(conn *tcp.Conn) tcp.SerRet {
210199

211200
if err != nil {
212201
target.Close()
213-
delete(connTable, addrKey)
202+
delete(connTable, address)
214203
continue
215204
}
216205

ui/builtin.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,10 @@ var _builtin_refs_assertions = map[string]Assert{
943943
"_": {Type: "string"},
944944
},
945945
},
946+
"interface": {
947+
Type: "ptr",
948+
Required: false,
949+
},
946950
},
947951
},
948952
"builtin::net::interface::sys": {
@@ -1712,12 +1716,20 @@ var _builtin_refs = map[string]Inst{
17121716
},
17131717
"builtin::trojan::server": func(spec *ArgNode) (any, error) {
17141718
passwords := spec.MustGet("passwords").ToStringList()
1719+
_interface := spec.MustGet("interface")
1720+
var underlying net.Interface
1721+
1722+
if _interface != nil {
1723+
underlying = _interface.Value.(net.Interface)
1724+
}
1725+
17151726
for i, password := range passwords {
17161727
sum := sha256.Sum224([]byte(password))
17171728
passwords[i] = hex.EncodeToString(sum[:])
17181729
}
17191730
return &trojan.Server{
17201731
PasswordHashes: passwords,
1732+
Underlying: underlying,
17211733
}, nil
17221734
},
17231735
"builtin::net::interface::sys": func(*ArgNode) (any, error) {

0 commit comments

Comments
 (0)