Skip to content

Commit

Permalink
Add rate limiting capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
vysotskylev committed Apr 14, 2022
1 parent c6485a6 commit 3986229
Show file tree
Hide file tree
Showing 12 changed files with 576 additions and 48 deletions.
3 changes: 3 additions & 0 deletions config_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ keys:
port: 9000
cipher: chacha20-ietf-poly1305
secret: Secret0
rate_limit: 128000

- id: user-1
port: 9000
cipher: chacha20-ietf-poly1305
secret: Secret1
rate_limit: 128000

- id: user-2
port: 9001
cipher: chacha20-ietf-poly1305
secret: Secret2
rate_limit: 128000
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
github.com/prometheus/procfs v0.1.3 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect
google.golang.org/protobuf v1.23.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4=
golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
Expand Down
205 changes: 198 additions & 7 deletions integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
return conn, &running
}

func makeLimiter(cipherList service.CipherList) service.TrafficLimiter {
c := service.MakeTestTrafficLimiterConfig(cipherList)
return service.NewTrafficLimiter(&c)
}

func TestTCPEcho(t *testing.T) {
echoListener, echoRunning := startTCPEchoServer(t)

Expand All @@ -111,7 +116,7 @@ func TestTCPEcho(t *testing.T) {
}
replayCache := service.NewReplayCache(5)
const testTimeout = 200 * time.Millisecond
proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout)
proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyListener)

Expand Down Expand Up @@ -164,6 +169,192 @@ func TestTCPEcho(t *testing.T) {
echoRunning.Wait()
}

func TestTrafficLimiterTCP(t *testing.T) {
echoListener, echoRunning := startTCPEchoServer(t)

proxyListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
secrets := ss.MakeTestSecrets(1)
cipherList, err := service.MakeTestCiphers(secrets)
if err != nil {
t.Fatal(err)
}
replayCache := service.NewReplayCache(5)
const testTimeout = 5 * time.Second

key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID
trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{
KeyToRateLimit: map[string]int{
key: 1000,
},
})

proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, trafficLimiter)
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyListener)

proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String())
if err != nil {
t.Fatal(err)
}
portNum, err := strconv.Atoi(proxyPort)
if err != nil {
t.Fatal(err)
}
client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher)
if err != nil {
t.Fatalf("Failed to create ShadowsocksClient: %v", err)
}

doWriteRead := func(N int, repeats int) time.Duration {
up := ss.MakeTestPayload(N)
conn, err := client.DialTCP(nil, echoListener.Addr().String())
defer conn.Close()
if err != nil {
t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err)
}
start := time.Now()
down := make([]byte, N)

for i := 0; i < repeats; i++ {
n, err := conn.Write(up)
if err != nil {
t.Fatal(err)
}
if n != N {
t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n)
}

n, err = io.ReadFull(conn, down)
if err != nil && err != io.EOF {
t.Fatal(err)
}
if n != N {
t.Fatalf("Expected to download %d bytes, but only received %d", N, n)
}

if !bytes.Equal(up, down) {
t.Fatal("Echo mismatch")
}
}

return time.Now().Sub(start)
}

period1 := doWriteRead(200, 4)
if period1 < 500*time.Millisecond {
t.Fatalf("Write-read loop is too fast")
}

time.Sleep(1 * time.Second)

period2 := doWriteRead(200, 2)
if period2 > 100*time.Millisecond {
t.Fatalf("Write-read loop is too slow")
}

period3 := doWriteRead(500, 2)
if period3 < 500*time.Millisecond {
t.Fatalf("Write-read loop is too fast")
}

proxy.Stop()
echoListener.Close()
echoRunning.Wait()
}

func TestTrafficLimiterUDP(t *testing.T) {
echoConn, echoRunning := startUDPEchoServer(t)

proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
secrets := ss.MakeTestSecrets(1)
cipherList, err := service.MakeTestCiphers(secrets)
if err != nil {
t.Fatal(err)
}
testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"}

key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID
trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{
KeyToRateLimit: map[string]int{
key: 1000,
},
})

proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, trafficLimiter)
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyConn)

proxyHost, proxyPort, err := net.SplitHostPort(proxyConn.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
portNum, err := strconv.Atoi(proxyPort)
if err != nil {
t.Fatal(err)
}
client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher)
if err != nil {
t.Fatalf("Failed to create ShadowsocksClient: %v", err)
}
conn, err := client.ListenUDP(nil)
if err != nil {
t.Fatalf("ShadowsocksClient.ListenUDP failed: %v", err)
}

run := func(N int, expectReadError bool) {
up := ss.MakeTestPayload(N)
n, err := conn.WriteTo(up, echoConn.LocalAddr())
if err != nil {
t.Fatal(err)
}
if n != N {
t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n)
}

conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))

down := make([]byte, N)
n, addr, err := conn.ReadFrom(down)
if err != nil {
if !expectReadError {
t.Fatalf("Unexpected read error: %v", err)
}
return
} else {
if expectReadError {
t.Fatalf("Expected read error")
}
}
if n != N {
t.Fatalf("Tried to download %d bytes, but only sent %d", N, n)
}
if addr.String() != echoConn.LocalAddr().String() {
t.Errorf("Reported address mismatch: %s != %s", addr.String(), echoConn.LocalAddr().String())
}

if !bytes.Equal(up, down) {
t.Fatal("Echo mismatch")
}
}

for i := 0; i < 3; i++ {
run(300, false)
run(300, true)
time.Sleep(time.Second)
}

conn.Close()
echoConn.Close()
echoRunning.Wait()
proxy.GracefulStop()
}

type statusMetrics struct {
metrics.NoOpMetrics
sync.Mutex
Expand All @@ -184,7 +375,7 @@ func TestRestrictedAddresses(t *testing.T) {
require.NoError(t, err)
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout)
proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout, makeLimiter(cipherList))
go proxy.Serve(proxyListener)

proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String())
Expand Down Expand Up @@ -266,7 +457,7 @@ func TestUDPEcho(t *testing.T) {
t.Fatal(err)
}
testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"}
proxy := service.NewUDPService(time.Hour, cipherList, testMetrics)
proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyConn)

Expand Down Expand Up @@ -363,7 +554,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
b.Fatal(err)
}
const testTimeout = 200 * time.Millisecond
proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout)
proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyListener)

Expand Down Expand Up @@ -430,7 +621,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
}
replayCache := service.NewReplayCache(service.MaxCapacity)
const testTimeout = 200 * time.Millisecond
proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout)
proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyListener)

Expand Down Expand Up @@ -505,7 +696,7 @@ func BenchmarkUDPEcho(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{})
proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyConn)

Expand Down Expand Up @@ -554,7 +745,7 @@ func BenchmarkUDPManyKeys(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{})
proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList))
proxy.SetTargetIPValidator(allowAll)
go proxy.Serve(proxyConn)

Expand Down
31 changes: 23 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type SSServer struct {
ports map[int]*ssPort
}

func (s *SSServer) startPort(portNum int) error {
func (s *SSServer) startPort(portNum int, trafficLimiterConfig *service.TrafficLimiterConfig) error {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum})
if err != nil {
return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err)
Expand All @@ -85,9 +85,11 @@ func (s *SSServer) startPort(portNum int) error {
}
logger.Infof("Listening TCP and UDP on port %v", portNum)
port := &ssPort{cipherList: service.NewCipherList()}

limiter := service.NewTrafficLimiter(trafficLimiterConfig)
// TODO: Register initial data metrics at zero.
port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m)
port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout, limiter)
port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m, limiter)
s.ports[portNum] = port
go port.tcpService.Serve(listener)
go port.udpService.Serve(packetConn)
Expand Down Expand Up @@ -120,6 +122,7 @@ func (s *SSServer) loadConfig(filename string) error {

portChanges := make(map[int]int)
portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry.
portKeyLimits := make(map[int]map[string]int)
for _, keyConfig := range config.Keys {
portChanges[keyConfig.Port] = 1
cipherList, ok := portCiphers[keyConfig.Port]
Expand All @@ -133,6 +136,13 @@ func (s *SSServer) loadConfig(filename string) error {
}
entry := service.MakeCipherEntry(keyConfig.ID, cipher, keyConfig.Secret)
cipherList.PushBack(&entry)
var keyLimits map[string]int
keyLimits, ok = portKeyLimits[keyConfig.Port]
if !ok {
keyLimits = make(map[string]int)
portKeyLimits[keyConfig.Port] = keyLimits
}
keyLimits[keyConfig.ID] = keyConfig.RateLimit
}
for port := range s.ports {
portChanges[port] = portChanges[port] - 1
Expand All @@ -143,7 +153,8 @@ func (s *SSServer) loadConfig(filename string) error {
return fmt.Errorf("Failed to remove port %v: %v", portNum, err)
}
} else if count == +1 {
if err := s.startPort(portNum); err != nil {
trafficLimiterConfig := &service.TrafficLimiterConfig{KeyToRateLimit: portKeyLimits[portNum]}
if err := s.startPort(portNum, trafficLimiterConfig); err != nil {
return fmt.Errorf("Failed to start port %v: %v", portNum, err)
}
}
Expand Down Expand Up @@ -193,10 +204,11 @@ func RunSSServer(filename string, natTimeout time.Duration, sm metrics.Shadowsoc

type Config struct {
Keys []struct {
ID string
Port int
Cipher string
Secret string
ID string
Port int
Cipher string
Secret string
RateLimit int
}
}

Expand All @@ -207,6 +219,9 @@ func readConfig(filename string) (*Config, error) {
return nil, err
}
err = yaml.Unmarshal(configData, &config)
if err != nil {
return nil, err
}
return &config, err
}

Expand Down
Loading

0 comments on commit 3986229

Please sign in to comment.