Skip to content

Commit

Permalink
bug fixes, add support for custom flood handling function, code refac…
Browse files Browse the repository at this point in the history
…torings.
  • Loading branch information
AmarnathCJD committed Aug 3, 2024
1 parent 451bead commit 281cd99
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 103 deletions.
10 changes: 0 additions & 10 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,3 @@ func (*errorSessionConfigsChanged) Error() string {
func (*errorSessionConfigsChanged) CRC() uint32 {
return 0x00000000
}

type errorReconnectRequired struct{}

func (*errorReconnectRequired) Error() string {
return "session configuration was changed, need to repeat request"
}

func (*errorReconnectRequired) CRC() uint32 {
return 0x00000000
}
20 changes: 10 additions & 10 deletions internal/transport/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
)

type tcpConn struct {
cancelReader *CancelableReader
conn *net.TCPConn
timeout time.Duration
reader *Reader
conn *net.TCPConn
timeout time.Duration
}

type TCPConnConfig struct {
Expand All @@ -39,9 +39,9 @@ func NewTCP(cfg TCPConnConfig) (Conn, error) {
}

return &tcpConn{
cancelReader: NewCancelableReader(cfg.Ctx, conn),
conn: conn,
timeout: cfg.Timeout,
reader: NewReader(cfg.Ctx, conn),
conn: conn,
timeout: cfg.Timeout,
}, nil
}

Expand All @@ -51,9 +51,9 @@ func newSocksTCP(cfg TCPConnConfig) (Conn, error) {
return nil, err
}
return &tcpConn{
cancelReader: NewCancelableReader(cfg.Ctx, conn),
conn: conn.(*net.TCPConn),
timeout: cfg.Timeout,
reader: NewReader(cfg.Ctx, conn),
conn: conn.(*net.TCPConn),
timeout: cfg.Timeout,
}, nil
}

Expand All @@ -73,7 +73,7 @@ func (t *tcpConn) Read(b []byte) (int, error) {
}
}

n, err := t.cancelReader.Read(b)
n, err := t.reader.Read(b)
if err != nil {
if e, ok := err.(*net.OpError); ok {
if e.Err.Error() == "i/o timeout" {
Expand Down
27 changes: 7 additions & 20 deletions internal/transport/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"strconv"
)

// CancelableReader is a wrapper around io.Reader, that allows to cancel read operation.
type CancelableReader struct {
// Reader is a wrapper around io.Reader, that allows to cancel read operation.
type Reader struct {
ctx context.Context
data chan []byte

Expand All @@ -20,7 +20,7 @@ type CancelableReader struct {
r io.Reader
}

func (c *CancelableReader) begin() {
func (c *Reader) begin() {
for {
select {
case sizeWant := <-c.sizeWant:
Expand All @@ -43,20 +43,7 @@ func (c *CancelableReader) begin() {
}
}

// func isClosed(ch <-chan int) bool {
// select {
// case <-ch:
// return true
// default:
// }
// return false
// }

func (c *CancelableReader) Read(p []byte) (int, error) {
// if isClosed(c.sizeWant) {
// return 0, c.err
// }

func (c *Reader) Read(p []byte) (int, error) {
select {
case <-c.ctx.Done():
return 0, c.ctx.Err()
Expand All @@ -75,7 +62,7 @@ func (c *CancelableReader) Read(p []byte) (int, error) {
}
}

func (c *CancelableReader) ReadByte() (byte, error) {
func (c *Reader) ReadByte() (byte, error) {
b := make([]byte, 1)

n, err := c.Read(b)
Expand All @@ -89,8 +76,8 @@ func (c *CancelableReader) ReadByte() (byte, error) {
return b[0], nil
}

func NewCancelableReader(ctx context.Context, r io.Reader) *CancelableReader {
c := &CancelableReader{
func NewReader(ctx context.Context, r io.Reader) *Reader {
c := &Reader{
r: r,
ctx: ctx,
data: make(chan []byte),
Expand Down
102 changes: 43 additions & 59 deletions mtproto.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ type MTProto struct {
serviceChannel chan tl.Object
serviceModeActivated bool

authKey404 []int64
shouldTransfer bool
authKey404 []int64

Logger *utils.Logger

serverRequestHandlers []func(i any) bool
floodHandler func(err error) bool
}

type Config struct {
Expand All @@ -84,6 +84,7 @@ type Config struct {
SessionStorage session.SessionLoader
MemorySession bool
AppID int32
FloodHandler func(err error) bool

ServerHost string
PublicKey *rsa.PublicKey
Expand Down Expand Up @@ -130,6 +131,7 @@ func NewMTProto(c Config) (*MTProto, error) {
memorySession: c.MemorySession,
appID: c.AppID,
proxy: c.Proxy,
floodHandler: func(err error) bool { return false },
}

mtproto.Logger.Debug("initializing mtproto...")
Expand All @@ -142,7 +144,9 @@ func NewMTProto(c Config) (*MTProto, error) {
return nil, errors.Wrap(err, "loading auth")
}

//go mtproto.checkBreaking()
if c.FloodHandler != nil {
mtproto.floodHandler = c.FloodHandler
}

return mtproto, nil
}
Expand Down Expand Up @@ -209,10 +213,6 @@ func (m *MTProto) ImportAuth(stringSession string) (bool, error) {
return true, nil
}

func (m *MTProto) SetTransfer(transfer bool) {
m.shouldTransfer = transfer
}

func (m *MTProto) GetDC() int {
return utils.SearchAddr(m.Addr)
}
Expand All @@ -231,10 +231,10 @@ func (m *MTProto) ReconnectToNewDC(dc int) (*MTProto, error) {
}
newAddr := utils.GetAddr(dc)
if newAddr == "" {
return nil, errors.New("invalid data center id provided")
return nil, errors.New("dc_id not found")
}

m.Logger.Debug("migrating to new dc... (dc: " + strconv.Itoa(dc) + ")")
m.Logger.Debug("migrating to new DC... dc-" + strconv.Itoa(dc))
m.sessionStorage.Delete()
m.Logger.Debug("deleted old auth key file")
cfg := Config{
Expand Down Expand Up @@ -342,8 +342,6 @@ func (m *MTProto) connect(ctx context.Context) error {
return fmt.Errorf("creating transport: %w", err)
}

m.SetTransfer(true)

go closeOnCancel(ctx, m.transport)
return nil
}
Expand All @@ -367,20 +365,18 @@ func (m *MTProto) makeRequest(data tl.Object, expectedTypes ...reflect.Type) (an
response := <-resp
switch r := response.(type) {
case *objects.RpcError:
// if err := RpcErrorToNative(r).(*ErrResponseCode); strings.Contains(err.Message, "FLOOD_WAIT_") {
// m.Logger.Info("flood wait detected on '" + strings.ReplaceAll(reflect.TypeOf(data).Elem().Name(), "Params", "") + fmt.Sprintf("' request. sleeping for %s", (time.Duration(realErr.AdditionalInfo.(int))*time.Second).String()))
// time.Sleep(time.Duration(realErr.AdditionalInfo.(int)) * time.Second)
// return m.makeRequest(data, expectedTypes...) TODO: implement flood wait correctly
// }
if err := RpcErrorToNative(r).(*ErrResponseCode); strings.Contains(err.Message, "FLOOD_WAIT_") || strings.Contains(err.Message, "FLOOD_PREMIUM_WAIT_") {
if done := m.floodHandler(err); !done {
return nil, RpcErrorToNative(r)
} else {
return m.makeRequest(data, expectedTypes...)
}
}
return nil, RpcErrorToNative(r)

case *errorSessionConfigsChanged:
m.Logger.Debug("session configs changed, resending request")
return m.makeRequest(data, expectedTypes...)

case *errorReconnectRequired:
m.Logger.Info("req info: " + fmt.Sprintf("%T", data))
return nil, errors.New("required to reconnect!")
}

return tl.UnwrapNativeTypes(response), nil
Expand All @@ -402,12 +398,6 @@ func (m *MTProto) Disconnect() error {
m.stopRoutines()
m.tcpActive = false

// for _, v := range m.responseChannels.Keys() {
// ch, _ := m.responseChannels.Get(v)
// ch <- &errorReconnectRequired{}
// }

// m.responseChannels.Close()
return nil
}

Expand Down Expand Up @@ -436,7 +426,6 @@ func (m *MTProto) Reconnect(WithLogs bool) error {
PingID: 123456789,
})

//m.MakeRequest(&utils.UpdatesGetStateParams{}) // to ask the server to send the updates
return errors.Wrap(err, "recreating connection")
}

Expand Down Expand Up @@ -506,21 +495,9 @@ func (m *MTProto) startReadingResponses(ctx context.Context) {
switch e := err.(type) {
case *ErrResponseCode:
if e.Code == 4294966892 {
if m.authKey404 == nil || len(m.authKey404) == 0 {
m.authKey404 = []int64{1, time.Now().Unix()}
} else {
if time.Now().Unix()-m.authKey404[1] < 30 { // repeated failures
m.authKey404[0]++
} else {
m.authKey404[0] = 1
}
m.authKey404[1] = time.Now().Unix()
}
if m.authKey404[0] > 3 {
panic("[AUTH_KEY_INVALID] the auth key is invalid and needs to be reauthenticated (code -404)")
} else {
m.Logger.Error(errors.New("(retry: " + strconv.FormatInt(m.authKey404[0], 10) + ") [AUTH_KEY_INVALID] the auth key is invalid and needs to be reauthenticated (code -404)"))
}
m.handle404Error()
} else {
m.Logger.Debug(errors.New("[RESPONSE_ERROR_CODE] - " + e.Error()))
}
case *transport.ErrCode:
m.Logger.Error(errors.New("[TRANSPORT_ERROR_CODE] - " + e.Error()))
Expand All @@ -537,6 +514,30 @@ func (m *MTProto) startReadingResponses(ctx context.Context) {
}()
}

func (m *MTProto) handle404Error() {
if m.authKey404 == nil || len(m.authKey404) == 0 {
m.authKey404 = []int64{1, time.Now().Unix()}
} else {
if time.Now().Unix()-m.authKey404[1] < 30 { // repeated failures
m.authKey404[0]++
} else {
m.authKey404[0] = 1
}
m.authKey404[1] = time.Now().Unix()
}
if m.authKey404[0] == 4 {
m.Logger.Error(errors.New("(last retry: 4) reconnecting due to [AUTH_KEY_INVALID] (code -404)"))
err := m.Reconnect(false)
if err != nil {
m.Logger.Error(errors.Wrap(err, "reconnecting"))
}
} else if m.authKey404[0] > 4 {
panic("[AUTH_KEY_INVALID] the auth key is invalid and needs to be reauthenticated (code -404)")
} else {
m.Logger.Error(errors.New("(retry: " + strconv.FormatInt(m.authKey404[0], 10) + ") [AUTH_KEY_INVALID] the auth key is invalid and needs to be reauthenticated (code -404)"))
}
}

func (m *MTProto) readMsg() error {
if m.transport == nil {
return errors.New("must setup connection before reading messages")
Expand Down Expand Up @@ -756,20 +757,3 @@ func closeOnCancel(ctx context.Context, c io.Closer) {
c.Close()
}()
}

// func (m *MTProto) checkBreaking() {
// ticker := time.NewTicker(2 * time.Second)
// defer ticker.Stop()

// for range ticker.C {
// if m.shouldTransfer && !m.TcpActive() {
// //m.CreateConnection(false)
// for i := 0; i < 2; i++ {
// //_, err := m.MakeRequest(&utils.UpdatesGetStateParams{})
// //if err == nil {
// break
// }
// }
// }
// }
// }
8 changes: 4 additions & 4 deletions telegram/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ import (
)

const (
// DefaultDC is the default data center id
// The Initial DC to connect to, before auth
DefaultDataCenter = 4
DisconnectExportedAfter = 5 * time.Minute
)

// TODO: fix session file issue

type clientData struct {
appID int32
appHash string
Expand Down Expand Up @@ -88,6 +86,7 @@ type ClientConfig struct {
TestMode bool
LogLevel string
Proxy *url.URL
FloodHandler func(err error) bool
}

type Session struct {
Expand Down Expand Up @@ -307,6 +306,8 @@ func (c *Client) Start() error {
} else if err != nil {
return err
}

c.stopCh = make(chan struct{}) // reset the stop channel
return nil
}

Expand Down Expand Up @@ -544,7 +545,6 @@ func (c *Client) Idle() {
func (c *Client) Stop() error {
close(c.stopCh)
go c.cleanExportedSenders()
c.MTProto.SetTransfer(false) // to stop connection break check.
return c.MTProto.Terminate()
}

Expand Down

0 comments on commit 281cd99

Please sign in to comment.