Skip to content

Commit

Permalink
enhancement(hy2): support human-readable bandwidth configuration
Browse files Browse the repository at this point in the history
Fixes #665
  • Loading branch information
douglarek committed Sep 29, 2024
1 parent 2a8b537 commit 83f3827
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 18 deletions.
4 changes: 2 additions & 2 deletions component/outbound/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func NewGlobalOption(global *config.Global, log *logrus.Logger) *GlobalOption {
AllowInsecure: global.AllowInsecure,
TlsImplementation: global.TlsImplementation,
UtlsImitate: global.UtlsImitate,
BandwidthMaxTx: global.BandwidthMaxTx,
BandwidthMaxRx: global.BandwidthMaxRx},
BandwidthMaxTx: global.BandwidthMaxTx.Uint64(),
BandwidthMaxRx: global.BandwidthMaxRx.Uint64()},
Log: log,
TcpCheckOptionRaw: TcpCheckOptionRaw{Raw: global.TcpCheckUrl, Log: log, ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae, global.Mptcp), Method: global.TcpCheckHttpMethod},
CheckDnsOptionRaw: CheckDnsOptionRaw{Raw: global.UdpCheckDns, ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae, global.Mptcp), Somark: global.SoMarkFromDae},
Expand Down
7 changes: 6 additions & 1 deletion component/outbound/dialer_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/pkg/logger"
"github.com/daeuniverse/outbound/pkg/fastrand"
"github.com/sirupsen/logrus"
)

const (
Expand All @@ -26,7 +27,11 @@ var TestNetworkType = &dialer.NetworkType{
IsDns: false,
}

var log = logger.NewLogger("trace", false, nil)
var log = logrus.New()

func init() {
logger.SetLogger(log, "trace", false, nil)
}

func newDirectDialer(option *dialer.GlobalOption, fullcone bool) *dialer.Dialer {
_d, p := dialer.NewDirectDialer(option, true)
Expand Down
18 changes: 9 additions & 9 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"reflect"
"time"

"github.com/daeuniverse/dae/pkg/bandwidth"
"github.com/daeuniverse/dae/pkg/config_parser"
)

Expand Down Expand Up @@ -37,15 +38,14 @@ type Global struct {
EnableLocalTcpFastRedirect bool `mapstructure:"enable_local_tcp_fast_redirect" default:"false"`
AutoConfigKernelParameter bool `mapstructure:"auto_config_kernel_parameter" default:"false"`
// DEPRECATED: not used as of https://github.com/daeuniverse/dae/pull/458
AutoConfigFirewallRule bool `mapstructure:"auto_config_firewall_rule" default:"false"`
SniffingTimeout time.Duration `mapstructure:"sniffing_timeout" default:"100ms"`
TlsImplementation string `mapstructure:"tls_implementation" default:"tls"`
UtlsImitate string `mapstructure:"utls_imitate" default:"chrome_auto"`
PprofPort uint16 `mapstructure:"pprof_port" default:"0"`
Mptcp bool `mapstructure:"mptcp" default:"false"`
// TODO: support input in human-readable format (e.g., 100Mbps, 1Gbps)
BandwidthMaxTx uint64 `mapstructure:"bandwidth_max_tx" default:"0"`
BandwidthMaxRx uint64 `mapstructure:"bandwidth_max_rx" default:"0"`
AutoConfigFirewallRule bool `mapstructure:"auto_config_firewall_rule" default:"false"`
SniffingTimeout time.Duration `mapstructure:"sniffing_timeout" default:"100ms"`
TlsImplementation string `mapstructure:"tls_implementation" default:"tls"`
UtlsImitate string `mapstructure:"utls_imitate" default:"chrome_auto"`
PprofPort uint16 `mapstructure:"pprof_port" default:"0"`
Mptcp bool `mapstructure:"mptcp" default:"false"`
BandwidthMaxTx bandwidth.UnitValue `mapstructure:"bandwidth_max_tx" default:"0"`
BandwidthMaxRx bandwidth.UnitValue `mapstructure:"bandwidth_max_rx" default:"0"`
}

type Utls struct {
Expand Down
16 changes: 13 additions & 3 deletions config/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/pkg/bandwidth"
"github.com/daeuniverse/dae/pkg/config_parser"
)

Expand Down Expand Up @@ -153,9 +154,18 @@ func ParamParser(to reflect.Value, section *config_parser.Section, ignoreType []
field.Val.Set(reflect.Append(field.Val, vPointerNew.Elem()))
}
default:
// Field is not interface{}, we can decode.
if !common.FuzzyDecode(field.Val.Addr().Interface(), itemVal.Val) {
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
// Special types like bandwidth.UnitValue require custom decoding as common.FuzzyDecode relies on the field's default value type.
if field.Val.Type() == reflect.TypeOf(bandwidth.UnitValue(0)) {
value, err := bandwidth.ConvBandwidth(itemVal.Val)
if err != nil {
return fmt.Errorf("failed to parse \"%v\": %v", itemVal.Key, err)
}
field.Val.Set(reflect.ValueOf(bandwidth.UnitValue(value)))
} else {
// Field is not interface{}, we can decode.
if !common.FuzzyDecode(field.Val.Addr().Interface(), itemVal.Val) {
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
}
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions example.dae
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ global {
mptcp: false

# The maximum bandwidth for accessing the Internet. It is useful for some specific protocols (e.g., Hysteria2),
# which will perform better with bandwith information provided. The unit is **byte** per second.
bandwidth_max_tx: 26214400 # 200Mbps == 25MB/s == 26214400 B/s uplink
bandwidth_max_rx: 131072000 # 1Gbps == 125MB/s == 131072000 B/s downlink
# which will perform better with bandwith information provided. The unit can be b, kb, mb, gb, tb or bytes per second.
# supported formats: https://v2.hysteria.network/docs/advanced/Full-Client-Config/#bandwidth
bandwidth_max_tx: '200 mbps' # uplink, or '200 m' or '200 mb' or '200 mbps' or 25000000(which is 200/8*1000*1000)
bandwidth_max_rx: '1 gbps' # downlink, or '1 g' or '1 gb' or '1 gbps' or 125000000(which is 1000/8*1000*1000)
}

# Subscriptions defined here will be resolved as nodes and merged as a part of the global node pool.
Expand Down
94 changes: 94 additions & 0 deletions pkg/bandwidth/bpsconv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2024, daeuniverse Organization <[email protected]>
*/

package bandwidth

import (
"errors"
"fmt"
"strconv"
"strings"
)

// UnitValue is a wrapper for uint64 to support custom unmarshaler.
type UnitValue uint64

// UnmarshalText parses a string into a UnitValue.
func (b *UnitValue) UnmarshalText(text []byte) error {
value, err := ConvBandwidth(string(text))
if err != nil {
return err
}
*b = UnitValue(value)
return nil
}

// Uint64 returns the underlying uint64 value.
func (b UnitValue) Uint64() uint64 {
return uint64(b)
}

// /////// Code reference from https://github.com/apernet/hysteria/blob/21ea2a0/app/internal/utils/bpsconv.go
const (
Byte = 1
Kilobyte = Byte * 1000
Megabyte = Kilobyte * 1000
Gigabyte = Megabyte * 1000
Terabyte = Gigabyte * 1000
)

// StringToBps converts a string to a bandwidth value in bytes per second.
// E.g. "100 Mbps", "512 kbps", "1g" are all valid.
func StringToBps(s string) (uint64, error) {
s = strings.ToLower(strings.TrimSpace(s))
spl := 0
for i, c := range s {
if c < '0' || c > '9' {
spl = i
break
}
}
if spl == 0 {
// No unit or no value
return 0, errors.New("invalid format")
}
v, err := strconv.ParseUint(s[:spl], 10, 64)
if err != nil {
return 0, err
}
unit := strings.TrimSpace(s[spl:])

switch strings.ToLower(unit) {
case "b", "bps":
return v * Byte / 8, nil
case "k", "kb", "kbps":
return v * Kilobyte / 8, nil
case "m", "mb", "mbps":
return v * Megabyte / 8, nil
case "g", "gb", "gbps":
return v * Gigabyte / 8, nil
case "t", "tb", "tbps":
return v * Terabyte / 8, nil
default:
return 0, errors.New("unsupported unit")
}
}

// ConvBandwidth handles both string and int types for bandwidth.
// When using string, it will be parsed as a bandwidth string with units.
// When using int, it will be parsed as a raw bandwidth in bytes per second.
// It does NOT support float types.
func ConvBandwidth(bw interface{}) (uint64, error) {
switch bwT := bw.(type) {
case string:
return StringToBps(bwT)
case int:
return uint64(bwT), nil
default:
return 0, fmt.Errorf("invalid type %T for bandwidth", bwT)
}
}

// /////// reference end

0 comments on commit 83f3827

Please sign in to comment.