diff --git a/telegram/cache.go b/telegram/cache.go index 5e1081d9..7ee50808 100644 --- a/telegram/cache.go +++ b/telegram/cache.go @@ -21,7 +21,6 @@ type CACHE struct { users map[int64]*UserObj channels map[int64]*Channel writeFile bool - file *os.File InputPeers *InputPeerCache `json:"input_peers,omitempty"` logger *utils.Logger } @@ -72,9 +71,6 @@ func NewCache(logLevel string, fileN string) *CACHE { // --------- Cache file Functions --------- func (c *CACHE) WriteFile() { - c.Lock() - defer c.Unlock() - file, err := os.OpenFile(c.fileN, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { c.logger.Error("error opening cache file: ", err) @@ -84,6 +80,7 @@ func (c *CACHE) WriteFile() { var buffer strings.Builder + c.Lock() for id, accessHash := range c.InputPeers.InputUsers { buffer.WriteString(fmt.Sprintf("1:%d:%d,", id, accessHash)) } @@ -95,6 +92,7 @@ func (c *CACHE) WriteFile() { for id, accessHash := range c.InputPeers.InputChannels { buffer.WriteString(fmt.Sprintf("3:%d:%d,", id, accessHash)) } + c.Unlock() if _, err := file.WriteString(buffer.String()); err != nil { c.logger.Error("error writing to cache file: ", err) @@ -102,10 +100,6 @@ func (c *CACHE) WriteFile() { } func (c *CACHE) ReadFile() { - // Lock only for cache modification - c.Lock() - defer c.Unlock() - file, err := os.Open(c.fileN) if err != nil && !os.IsNotExist(err) { c.logger.Error("error opening cache file: ", err) @@ -116,7 +110,6 @@ func (c *CACHE) ReadFile() { buffer := make([]byte, 1) var data []byte totalLoaded := 0 - for { _, err := file.Read(buffer) if err != nil { @@ -160,6 +153,8 @@ func (c *CACHE) processData(data []byte) bool { } // process data + c.Lock() + defer c.Unlock() switch splitData[0] { case "1": c.InputPeers.InputUsers[int64(id)] = int64(accessHash) @@ -175,6 +170,9 @@ func (c *CACHE) processData(data []byte) bool { } func (c *CACHE) getUserPeer(userID int64) (InputUser, error) { + c.RLock() + defer c.RUnlock() + if userHash, ok := c.InputPeers.InputUsers[userID]; ok { return &InputUserObj{UserID: userID, AccessHash: userHash}, nil } @@ -183,6 +181,9 @@ func (c *CACHE) getUserPeer(userID int64) (InputUser, error) { } func (c *CACHE) getChannelPeer(channelID int64) (InputChannel, error) { + c.RLock() + defer c.RUnlock() + if channelHash, ok := c.InputPeers.InputChannels[channelID]; ok { return &InputChannelObj{ChannelID: channelID, AccessHash: channelHash}, nil } @@ -224,11 +225,11 @@ func (c *CACHE) GetInputPeer(peerID int64) (InputPeer, error) { func (c *Client) getUserFromCache(userID int64) (*UserObj, error) { c.Cache.RLock() - defer c.Cache.RUnlock() - if user, found := c.Cache.users[userID]; found { + c.Cache.RUnlock() return user, nil } + c.Cache.RUnlock() userPeer, err := c.Cache.getUserPeer(userID) if err != nil { @@ -254,11 +255,11 @@ func (c *Client) getUserFromCache(userID int64) (*UserObj, error) { func (c *Client) getChannelFromCache(channelID int64) (*Channel, error) { c.Cache.RLock() - defer c.Cache.RUnlock() - if channel, found := c.Cache.channels[channelID]; found { + c.Cache.RUnlock() return channel, nil } + c.Cache.RUnlock() channelPeer, err := c.Cache.getChannelPeer(channelID) if err != nil { @@ -289,11 +290,11 @@ func (c *Client) getChannelFromCache(channelID int64) (*Channel, error) { func (c *Client) getChatFromCache(chatID int64) (*ChatObj, error) { c.Cache.RLock() - defer c.Cache.RUnlock() - if chat, found := c.Cache.chats[chatID]; found { + c.Cache.RUnlock() return chat, nil } + c.Cache.RUnlock() chat, err := c.MessagesGetChats([]int64{chatID}) if err != nil { @@ -447,6 +448,9 @@ func (cache *CACHE) UpdatePeersToCache(users []User, chats []Chat) { } func (c *Client) GetPeerUser(userID int64) (*InputPeerUser, error) { + c.Cache.RLock() + defer c.Cache.RUnlock() + if peer, ok := c.Cache.InputPeers.InputUsers[userID]; ok { return &InputPeerUser{UserID: userID, AccessHash: peer}, nil } @@ -454,6 +458,8 @@ func (c *Client) GetPeerUser(userID int64) (*InputPeerUser, error) { } func (c *Client) GetPeerChannel(channelID int64) (*InputPeerChannel, error) { + c.Cache.RLock() + defer c.Cache.RUnlock() if peer, ok := c.Cache.InputPeers.InputChannels[channelID]; ok { return &InputPeerChannel{ChannelID: channelID, AccessHash: peer}, nil @@ -462,6 +468,9 @@ func (c *Client) GetPeerChannel(channelID int64) (*InputPeerChannel, error) { } func (c *Client) IdInCache(id int64) bool { + c.Cache.RLock() + defer c.Cache.RUnlock() + _, ok := c.Cache.InputPeers.InputUsers[id] if ok { return true diff --git a/telegram/client.go b/telegram/client.go index 030deeae..4f5fd67f 100644 --- a/telegram/client.go +++ b/telegram/client.go @@ -395,7 +395,8 @@ func (c *Client) CreateExportedSender(dcID int) (*Client, error) { if err != nil { return nil, errors.Wrap(err, "exporting new sender") } - exportedSender := &Client{MTProto: exported, Cache: c.Cache, Log: utils.NewLogger("gogram - sender").SetLevel(c.Log.Lev()), wg: sync.WaitGroup{}, clientData: c.clientData, stopCh: make(chan struct{})} + + exportedSender := &Client{MTProto: exported, Cache: NewCache(LogDisable, ""), Log: utils.NewLogger("gogram - sender").SetLevel(c.Log.Lev()), wg: sync.WaitGroup{}, clientData: c.clientData, stopCh: make(chan struct{})} initialReq := &InitConnectionParams{ ApiID: c.clientData.appID, diff --git a/telegram/media.go b/telegram/media.go index afc6eae7..83540dfa 100644 --- a/telegram/media.go +++ b/telegram/media.go @@ -303,13 +303,13 @@ func countWorkers(parts int64) int { } else if parts <= 50 { return 3 } else if parts <= 100 { - return 5 - } else if parts <= 200 { return 6 - } else if parts <= 400 { + } else if parts <= 200 { return 7 - } else if parts <= 500 { + } else if parts <= 400 { return 8 + } else if parts <= 500 { + return 10 } else { return 12 // not recommended to use more than 15 workers } @@ -392,7 +392,6 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri totalParts := parts - wg := sync.WaitGroup{} numWorkers := countWorkers(parts) if opts.Threads > 0 { numWorkers = opts.Threads @@ -440,20 +439,36 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri } }() + mu := sync.Mutex{} + wg := sync.WaitGroup{} + for p := int64(0); p < parts; p++ { wg.Add(1) - for { - found := false - for i := 0; i < numWorkers; i++ { - if !w[i].buzy && w[i].c != nil { + go func(p int64) { + defer wg.Done() + + for { + mu.Lock() + found := false + var workerIndex int + for i := 0; i < numWorkers; i++ { + if !w[i].buzy && w[i].c != nil { + found = true + w[i].buzy = true + workerIndex = i + break + } + } + mu.Unlock() - found = true - w[i].buzy = true + if found { go func(i int, p int) { defer func() { + mu.Lock() w[i].buzy = false - wg.Done() + mu.Unlock() }() + retryCount := 0 reqTimeout := 3 * time.Second @@ -499,32 +514,27 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri if opts.ProgressCallback != nil { go opts.ProgressCallback(int32(totalParts), int32(p)) } - w[i].buzy = false case err := <-errorChan: if handleIfFlood(err, c) { goto partDownloadStartPoint } c.Logger.Error(err) - w[i].buzy = false case <-time.After(reqTimeout): c.Logger.Debug(fmt.Errorf("upload part %d timed out - retrying", p)) retryCount++ - if retryCount > 5 { + if retryCount > 3 { c.Logger.Debug(fmt.Errorf("upload part %d timed out - giving up", p)) return - } else if retryCount > 3 { + } else if retryCount > 2 { reqTimeout = 5 * time.Second } goto partDownloadStartPoint } - }(i, int(p)) + }(workerIndex, int(p)) + break } } - - if found { - break - } - } + }(p) } wg.Wait()