diff --git a/internal/utils/utils.go b/internal/utils/utils.go index fb40f5f9..ec678ef2 100755 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -67,10 +67,6 @@ func (*PingParams) CRC() uint32 { type UpdatesGetStateParams struct{} -func (*UpdatesGetStateParams) CRC() uint32 { - return 0xedd4882a -} - func NewMsgIDGenerator() func(timeOffset int64) int64 { var ( mu sync.Mutex diff --git a/mtproto.go b/mtproto.go index e7eaa417..5625622b 100755 --- a/mtproto.go +++ b/mtproto.go @@ -225,7 +225,7 @@ func (m *MTProto) SetAppID(appID int32) { m.appID = appID } -func (m *MTProto) ReconnectToNewDC(dc int) (*MTProto, error) { +func (m *MTProto) SwitchDc(dc int) (*MTProto, error) { if m.noRedirect { return m, nil } diff --git a/telegram/auth.go b/telegram/auth.go index 0a2a5b4f..304ece41 100755 --- a/telegram/auth.go +++ b/telegram/auth.go @@ -82,7 +82,7 @@ func (c *Client) LoginBot(botToken string) error { } if au, e := c.IsAuthorized(); !au { if dc, code := getErrorCode(e); code == 303 { - err = c.switchDc(dc) + err = c.SwitchDc(dc) if err != nil { return err } @@ -100,7 +100,7 @@ func (c *Client) SendCode(phoneNumber string) (hash string, err error) { }) if err != nil { if dc, code := getErrorCode(err); code == 303 { - err = c.switchDc(dc) + err = c.SwitchDc(dc) if err != nil { return "", err } @@ -564,7 +564,7 @@ func (q *QrToken) Wait(timeout ...int32) error { QrResponseSwitch: switch req := resp.(type) { case *AuthLoginTokenMigrateTo: - q.client.switchDc(int(req.DcID)) + q.client.SwitchDc(int(req.DcID)) resp, err = q.client.AuthImportLoginToken(req.Token) if err != nil { return err @@ -602,7 +602,7 @@ func (c *Client) QRLogin(IgnoreIDs ...int64) (*QrToken, error) { ) switch qr := qr.(type) { case *AuthLoginTokenMigrateTo: - c.switchDc(int(qr.DcID)) + c.SwitchDc(int(qr.DcID)) return c.QRLogin(IgnoreIDs...) case *AuthLoginTokenObj: qrToken = qr.Token diff --git a/telegram/cache.go b/telegram/cache.go index 8aa73893..5e1081d9 100644 --- a/telegram/cache.go +++ b/telegram/cache.go @@ -5,6 +5,7 @@ package telegram import ( "encoding/json" "fmt" + "io" "os" "strconv" "strings" @@ -72,82 +73,72 @@ func NewCache(logLevel string, fileN string) *CACHE { // --------- Cache file Functions --------- func (c *CACHE) WriteFile() { c.Lock() - defer c.Unlock() // necessary? - - if c.file == nil { - var err error - c.file, err = os.OpenFile(c.fileN, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - c.logger.Error("error opening cache file: ", err) - return - } + defer c.Unlock() - defer c.file.Close() + file, err := os.OpenFile(c.fileN, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + c.logger.Error("error opening cache file: ", err) + return } + defer file.Close() - // write format: 'type:id:access_hash,...' - // type: 1 for user, 2 for chat, 3 for channel + var buffer strings.Builder for id, accessHash := range c.InputPeers.InputUsers { - _, _ = c.file.WriteString(fmt.Sprintf("1:%d:%d,", id, accessHash)) + buffer.WriteString(fmt.Sprintf("1:%d:%d,", id, accessHash)) } + for id, accessHash := range c.InputPeers.InputChats { - _, _ = c.file.WriteString(fmt.Sprintf("2:%d:%d,", id, accessHash)) + buffer.WriteString(fmt.Sprintf("2:%d:%d,", id, accessHash)) } + for id, accessHash := range c.InputPeers.InputChannels { - _, _ = c.file.WriteString(fmt.Sprintf("3:%d:%d,", id, accessHash)) + buffer.WriteString(fmt.Sprintf("3:%d:%d,", id, accessHash)) } - c.file = nil + if _, err := file.WriteString(buffer.String()); err != nil { + c.logger.Error("error writing to cache file: ", err) + } } func (c *CACHE) ReadFile() { + // Lock only for cache modification c.Lock() defer c.Unlock() - if c.file == nil { - var err error - c.file, err = os.Open(c.fileN) - if err != nil && !os.IsNotExist(err) { - c.logger.Error("error opening cache file: ", err) - return - } - - defer c.file.Close() + file, err := os.Open(c.fileN) + if err != nil && !os.IsNotExist(err) { + c.logger.Error("error opening cache file: ", err) + return } - - // read till each , using buffer - // format: 'type:id:access_hash,...' - - _totalLoaded := 0 + defer file.Close() buffer := make([]byte, 1) var data []byte + totalLoaded := 0 + for { - _, err := c.file.Read(buffer) + _, err := file.Read(buffer) if err != nil { + if err != io.EOF { + c.logger.Debug("error reading from cache file: ", err) + } break } if buffer[0] == ',' { - // process data - // data format: 'type:id:access_hash' data = append(data, buffer[0]) - // process data if processed := c.processData(data); processed { - _totalLoaded++ + totalLoaded++ } - // reset data data = nil } else { data = append(data, buffer[0]) } } - if _totalLoaded != 0 { - c.logger.Debug("loaded ", _totalLoaded, " peers from cacheFile") + if totalLoaded != 0 { + c.logger.Debug("loaded ", totalLoaded, " peers from cacheFile") } - - c.file = nil } func (c *CACHE) processData(data []byte) bool { @@ -184,20 +175,18 @@ func (c *CACHE) processData(data []byte) bool { } func (c *CACHE) getUserPeer(userID int64) (InputUser, error) { - for userId, accessHash := range c.InputPeers.InputUsers { - if userId == userID { - return &InputUserObj{UserID: userId, AccessHash: accessHash}, nil - } + if userHash, ok := c.InputPeers.InputUsers[userID]; ok { + return &InputUserObj{UserID: userID, AccessHash: userHash}, nil } + return nil, fmt.Errorf("no user with id %d or missing from cache", userID) } func (c *CACHE) getChannelPeer(channelID int64) (InputChannel, error) { - for channelId, channelHash := range c.InputPeers.InputChannels { - if channelId == channelID { - return &InputChannelObj{ChannelID: channelId, AccessHash: channelHash}, nil - } + if channelHash, ok := c.InputPeers.InputChannels[channelID]; ok { + return &InputChannelObj{ChannelID: channelID, AccessHash: channelHash}, nil } + return nil, fmt.Errorf("no channel with id %d or missing from cache", channelID) } @@ -215,21 +204,19 @@ func (c *CACHE) GetInputPeer(peerID int64) (InputPeer, error) { } c.RLock() defer c.RUnlock() - for userId, userHash := range c.InputPeers.InputUsers { - if userId == peerID { - return &InputPeerUser{userId, userHash}, nil - } + + if userHash, ok := c.InputPeers.InputUsers[peerID]; ok { + return &InputPeerUser{peerID, userHash}, nil } - for chatId := range c.InputPeers.InputChats { - if chatId == peerID { - return &InputPeerChat{ChatID: chatId}, nil - } + + if _, ok := c.InputPeers.InputChats[peerID]; ok { + return &InputPeerChat{ChatID: peerID}, nil } - for channelId, channelHash := range c.InputPeers.InputChannels { - if channelId == peerID { - return &InputPeerChannel{channelId, channelHash}, nil - } + + if channelHash, ok := c.InputPeers.InputChannels[peerID]; ok { + return &InputPeerChannel{peerID, channelHash}, nil } + return nil, fmt.Errorf("there is no peer with id %d or missing from cache", peerID) } @@ -238,26 +225,30 @@ func (c *CACHE) GetInputPeer(peerID int64) (InputPeer, error) { func (c *Client) getUserFromCache(userID int64) (*UserObj, error) { c.Cache.RLock() defer c.Cache.RUnlock() - for _, user := range c.Cache.users { - if user.ID == userID { - return user, nil - } + + if user, found := c.Cache.users[userID]; found { + return user, nil } + userPeer, err := c.Cache.getUserPeer(userID) if err != nil { return nil, err } + users, err := c.UsersGetUsers([]InputUser{userPeer}) if err != nil { return nil, err } + if len(users) == 0 { return nil, fmt.Errorf("no user with id %d", userID) } + user, ok := users[0].(*UserObj) if !ok { - return nil, fmt.Errorf("no user with id %d", userID) + return nil, fmt.Errorf("expected UserObj for id %d, but got different type", userID) } + return user, nil } @@ -265,56 +256,64 @@ func (c *Client) getChannelFromCache(channelID int64) (*Channel, error) { c.Cache.RLock() defer c.Cache.RUnlock() - for _, channel := range c.Cache.channels { - if channel.ID == channelID { - return channel, nil - } + if channel, found := c.Cache.channels[channelID]; found { + return channel, nil } + channelPeer, err := c.Cache.getChannelPeer(channelID) if err != nil { return nil, err } + channels, err := c.ChannelsGetChannels([]InputChannel{channelPeer}) if err != nil { return nil, err } + channelsObj, ok := channels.(*MessagesChatsObj) if !ok { - return nil, fmt.Errorf("no channel with id %d or missing from cache", channelID) + return nil, fmt.Errorf("expected MessagesChatsObj for channel id %d, but got different type", channelID) } + if len(channelsObj.Chats) == 0 { - return nil, fmt.Errorf("no channel with id %d or missing from cache", channelID) + return nil, fmt.Errorf("no channel with id %d", channelID) } + channel, ok := channelsObj.Chats[0].(*Channel) if !ok { - return nil, fmt.Errorf("no channel with id %d or missing from cache", channelID) + return nil, fmt.Errorf("expected Channel for id %d, but got different type", channelID) } + return channel, nil } func (c *Client) getChatFromCache(chatID int64) (*ChatObj, error) { c.Cache.RLock() defer c.Cache.RUnlock() - for _, chat := range c.Cache.chats { - if chat.ID == chatID { - return chat, nil - } + + if chat, found := c.Cache.chats[chatID]; found { + return chat, nil } + chat, err := c.MessagesGetChats([]int64{chatID}) if err != nil { return nil, err } + chatsObj, ok := chat.(*MessagesChatsObj) if !ok { - return nil, fmt.Errorf("no chat with id %d or missing from cache", chatID) + return nil, fmt.Errorf("expected MessagesChatsObj for chat id %d, but got different type", chatID) } + if len(chatsObj.Chats) == 0 { - return nil, fmt.Errorf("no chat with id %d or missing from cache", chatID) + return nil, fmt.Errorf("no chat with id %d", chatID) } + chatObj, ok := chatsObj.Chats[0].(*ChatObj) if !ok { - return nil, fmt.Errorf("no chat with id %d or missing from cache", chatID) + return nil, fmt.Errorf("expected ChatObj for id %d, but got different type", chatID) } + return chatObj, nil } @@ -390,6 +389,7 @@ func (c *CACHE) UpdateChannel(channel *Channel) bool { } return false } + c.channels[channel.ID] = channel c.InputPeers.InputChannels[channel.ID] = channel.AccessHash @@ -410,43 +410,39 @@ func (c *CACHE) UpdateChat(chat *ChatObj) bool { return true } -func (cache *CACHE) UpdatePeersToCache(u []User, c []Chat) { - _totalUpdates := [2]int{0, 0} - for _, user := range u { - us, ok := user.(*UserObj) - if ok { - if upd := cache.UpdateUser(us); upd { - _totalUpdates[0]++ +func (cache *CACHE) UpdatePeersToCache(users []User, chats []Chat) { + totalUpdates := [2]int{0, 0} + + for _, user := range users { + if us, ok := user.(*UserObj); ok { + if updated := cache.UpdateUser(us); updated { + totalUpdates[0]++ } } } - for _, chat := range c { - ch, ok := chat.(*ChatObj) - if ok { - if upd := cache.UpdateChat(ch); upd { - _totalUpdates[1]++ + + for _, chat := range chats { + if ch, ok := chat.(*ChatObj); ok { + if updated := cache.UpdateChat(ch); updated { + totalUpdates[1]++ } - } else { - channel, ok := chat.(*Channel) - if ok { - if upd := cache.UpdateChannel(channel); upd { - _totalUpdates[1]++ - } + } else if channel, ok := chat.(*Channel); ok { + if updated := cache.UpdateChannel(channel); updated { + totalUpdates[1]++ } } } - if _totalUpdates[0] != 0 || _totalUpdates[1] != 0 { + if totalUpdates[0] > 0 || totalUpdates[1] > 0 { if cache.writeFile { - go cache.WriteFile() // write to file - } - if _totalUpdates[0] != 0 && _totalUpdates[1] != 0 { - cache.logger.Debug("updated ", _totalUpdates[0], "(u) and ", _totalUpdates[1], "(c) to ", cache.fileN, " (u:", len(cache.InputPeers.InputUsers), ", c:", len(cache.InputPeers.InputChats), ")") - } else if _totalUpdates[0] != 0 { - cache.logger.Debug("updated ", _totalUpdates[0], "(u) to ", cache.fileN, " (u:", len(cache.InputPeers.InputUsers), ", c:", len(cache.InputPeers.InputChats), ")") - } else { - cache.logger.Debug("updated ", _totalUpdates[1], "(c) to ", cache.fileN, " (u:", len(cache.InputPeers.InputUsers), ", c:", len(cache.InputPeers.InputChats), ")") + go cache.WriteFile() // Write to file asynchronously } + cache.logger.Debug( + fmt.Sprintf("updated %d users and %d chats to %s (users: %d, chats: %d)", + totalUpdates[0], totalUpdates[1], cache.fileN, + len(cache.InputPeers.InputUsers), len(cache.InputPeers.InputChats), + ), + ) } } diff --git a/telegram/client.go b/telegram/client.go index 25b49a0a..030deeae 100644 --- a/telegram/client.go +++ b/telegram/client.go @@ -4,6 +4,7 @@ package telegram import ( "crypto/rsa" + "fmt" "log" "net/url" "os" @@ -44,6 +45,7 @@ type clientData struct { type exportedSender struct { client *Client dcID int + added time.Time } type cachedExportedSenders struct { @@ -126,6 +128,8 @@ func NewClient(config ClientConfig) (*Client, error) { if err := client.clientWarnings(config); err != nil { return nil, err } + go client.loopForCleaningExpiredSenders() // start the loop for cleaning expired senders + return client, nil } @@ -328,9 +332,9 @@ func (c *Client) Disconnect() error { } // switchDC permanently switches the data center -func (c *Client) switchDc(dcID int) error { +func (c *Client) SwitchDc(dcID int) error { c.Log.Debug("switching data center to (" + strconv.Itoa(dcID) + ")") - newDcSender, err := c.MTProto.ReconnectToNewDC(dcID) + newDcSender, err := c.MTProto.SwitchDc(dcID) if err != nil { return errors.Wrap(err, "reconnecting to new dc") } @@ -351,22 +355,23 @@ func (c *Client) AddNewExportedSenderToMap(dcID int, sender *Client) { c.exportedSenders.Lock() c.exportedSenders.senders = append( c.exportedSenders.senders, - exportedSender{client: sender, dcID: dcID}, + exportedSender{client: sender, dcID: dcID, added: time.Now()}, ) c.exportedSenders.Unlock() +} - go func() { +func (c *Client) loopForCleaningExpiredSenders() { + for { time.Sleep(DisconnectExportedAfter) c.exportedSenders.Lock() - defer c.exportedSenders.Unlock() - for i, s := range c.exportedSenders.senders { - if s.client == sender { + if time.Since(s.added) > DisconnectExportedAfter { + s.client.Terminate() c.exportedSenders.senders = append(c.exportedSenders.senders[:i], c.exportedSenders.senders[i+1:]...) - break } } - }() // remove the sender from the map after the expiry time + c.exportedSenders.Unlock() + } } func (c *Client) GetCachedExportedSenders(dcID int) []*Client { @@ -391,31 +396,32 @@ func (c *Client) CreateExportedSender(dcID int) (*Client, error) { return nil, errors.Wrap(err, "exporting new sender") } exportedSender := &Client{MTProto: exported, Cache: c.Cache, Log: utils.NewLogger("gogram - sender").SetLevel(c.Log.Lev()), wg: sync.WaitGroup{}, clientData: c.clientData, stopCh: make(chan struct{})} - err = exportedSender.InitialRequest() - if err != nil { - return nil, errors.Wrap(err, "initial request") + + initialReq := &InitConnectionParams{ + ApiID: c.clientData.appID, + DeviceModel: c.clientData.deviceModel, + SystemVersion: c.clientData.systemVersion, + AppVersion: c.clientData.appVersion, + SystemLangCode: c.clientData.langCode, + LangCode: c.clientData.langCode, + Query: &HelpGetConfigParams{}, } if c.MTProto.GetDC() != exported.GetDC() { - if err := exportedSender.shareAuth(c, exportedSender.MTProto.GetDC()); err != nil { - return nil, errors.Wrap(err, "sharing auth") + c.Log.Info(fmt.Sprintf("exporting auth for data-center %d", exported.GetDC())) + auth, err := c.AuthExportAuthorization(int32(exported.GetDC())) + if err != nil { + return nil, errors.Wrap(err, "exporting auth") } - } - c.Log.Debug("exported sender for dc ", exported.GetDC(), " is ready") - return exportedSender, nil -} -// shareAuth shares authorization with another client -func (c *Client) shareAuth(main *Client, dcID int) error { - mainAuth, err := main.AuthExportAuthorization(int32(dcID)) - if err != nil || mainAuth == nil { - return errors.Wrap(err, "exporting authorization") - } - _, err = c.AuthImportAuthorization(mainAuth.ID, mainAuth.Bytes) - if err != nil { - return errors.Wrap(err, "importing authorization") + initialReq.Query = &AuthImportAuthorizationParams{ + ID: auth.ID, + Bytes: auth.Bytes, + } } - return nil + + _, err = exportedSender.InvokeWithLayer(ApiVersion, initialReq) + return exportedSender, err } // cleanExportedSenders terminates all exported senders and removes them from cache @@ -449,7 +455,7 @@ func (c *Client) GetDC() int { // This string can be used to import the session later func (c *Client) ExportSession() string { authSession, dcId := c.MTProto.ExportAuth() - c.Log.Debug("exporting string session...") + c.Log.Debug("exporting auth to string session...") return session.NewStringSession(authSession.Key, authSession.Hash, dcId, authSession.Hostname, authSession.AppID).Encode() } @@ -458,7 +464,7 @@ func (c *Client) ExportSession() string { // Params: // sessionString: The sessionString to authenticate with func (c *Client) ImportSession(sessionString string) (bool, error) { - c.Log.Debug("importing session: ", sessionString) + c.Log.Debug("importing auth from string session...") return c.MTProto.ImportAuth(sessionString) } diff --git a/telegram/helpers.go b/telegram/helpers.go index 1f88b0f3..01872788 100644 --- a/telegram/helpers.go +++ b/telegram/helpers.go @@ -746,7 +746,7 @@ func (c *Client) gatherVideoThumb(path string, duration int64) (InputFile, error // ffmpeg -i input.mp4 -ss 00:00:01.000 -vframes 1 output.png getPosition := func(duration int64) int64 { if duration <= 10 { - return duration + return (duration / 2) + 1 } else { return int64(rand.Int31n(int32(duration)/2) + 1) } diff --git a/telegram/media.go b/telegram/media.go index a1f83fbd..afc6eae7 100644 --- a/telegram/media.go +++ b/telegram/media.go @@ -391,9 +391,6 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri partOver := size % int64(partSize) totalParts := parts - if partOver > 0 { - totalParts++ - } wg := sync.WaitGroup{} numWorkers := countWorkers(parts) @@ -449,45 +446,78 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri found := false for i := 0; i < numWorkers; i++ { if !w[i].buzy && w[i].c != nil { + found = true - part := make([]byte, partSize) w[i].buzy = true - go func(i int, part []byte, p int) { - defer wg.Done() - downloadStartPoint: - c.Logger.Debug(fmt.Sprintf("downloading part %d/%d in chunks of %d", p, totalParts, len(part)/1024)) - buf, err := w[i].c.UploadGetFile(&UploadGetFileParams{ - Location: location, - Offset: int64(p * partSize), - Limit: int32(partSize), - CdnSupported: false, - }) - - if handleIfFlood(err, c) { - goto downloadStartPoint - } + go func(i int, p int) { + defer func() { + w[i].buzy = false + wg.Done() + }() + retryCount := 0 + reqTimeout := 3 * time.Second + + partDownloadStartPoint: + c.Logger.Debug(fmt.Sprintf("download part %d/%d in chunks of %d", p, totalParts, partSize/1024)) + + resultChan := make(chan UploadFile, 1) + errorChan := make(chan error, 1) + + go func() { + upl, err := w[i].c.UploadGetFile(&UploadGetFileParams{ + Location: location, + Offset: int64(p * partSize), + Limit: int32(partSize), + CdnSupported: false, + }) + if err != nil { + errorChan <- err + return + } + resultChan <- upl + }() - if err != nil || buf == nil { - w[i].c.Logger.Warn(err) - return - } - var buffer []byte - switch v := buf.(type) { - case *UploadFileObj: - buffer = v.Bytes - case *UploadFileCdnRedirect: - return // TODO - } - _, err = fs.WriteAt(buffer, int64(p*partSize)) - if err != nil { + select { + case upl := <-resultChan: + if upl == nil { + goto partDownloadStartPoint + } + + var buffer []byte + switch v := upl.(type) { + case *UploadFileObj: + buffer = v.Bytes + case *UploadFileCdnRedirect: + panic("cdn redirect not impl") // TODO + } + + _, err := fs.WriteAt(buffer, int64(p*partSize)) + if err != nil { + c.Logger.Error(err) + } + + if opts.ProgressCallback != nil { + go opts.ProgressCallback(int32(totalParts), int32(p)) + } + w[i].buzy = false + case err := <-errorChan: + if handleIfFlood(err, c) { + goto partDownloadStartPoint + } c.Logger.Error(err) + w[i].buzy = false + case <-time.After(reqTimeout): + c.Logger.Debug(fmt.Errorf("upload part %d timed out - retrying", p)) + retryCount++ + if retryCount > 5 { + c.Logger.Debug(fmt.Errorf("upload part %d timed out - giving up", p)) + return + } else if retryCount > 3 { + reqTimeout = 5 * time.Second + } + goto partDownloadStartPoint } - if opts.ProgressCallback != nil { - go opts.ProgressCallback(int32(totalParts), int32(p)) - } - w[i].buzy = false - }(i, part, int(p)) - break + }(i, int(p)) } } @@ -535,7 +565,7 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri c.Logger.Error(err) } case *UploadFileCdnRedirect: - return "", nil // TODO + panic("cdn redirect not impl") // TODO } if opts.ProgressCallback != nil {