Skip to content

Commit

Permalink
#130, #129, download-media add lock for sender buzy field, cache lock…
Browse files Browse the repository at this point in the history
…s complex everywhere 😁
  • Loading branch information
AmarnathCJD committed Aug 6, 2024
1 parent a139b83 commit 74c4115
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 38 deletions.
39 changes: 24 additions & 15 deletions telegram/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ type CACHE struct {
users map[int64]*UserObj
channels map[int64]*Channel
writeFile bool
file *os.File
InputPeers *InputPeerCache `json:"input_peers,omitempty"`
logger *utils.Logger
}
Expand Down Expand Up @@ -72,9 +71,6 @@ func NewCache(logLevel string, fileN string) *CACHE {

// --------- Cache file Functions ---------
func (c *CACHE) WriteFile() {
c.Lock()
defer c.Unlock()

file, err := os.OpenFile(c.fileN, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
c.logger.Error("error opening cache file: ", err)
Expand All @@ -84,6 +80,7 @@ func (c *CACHE) WriteFile() {

var buffer strings.Builder

c.Lock()
for id, accessHash := range c.InputPeers.InputUsers {
buffer.WriteString(fmt.Sprintf("1:%d:%d,", id, accessHash))
}
Expand All @@ -95,17 +92,14 @@ func (c *CACHE) WriteFile() {
for id, accessHash := range c.InputPeers.InputChannels {
buffer.WriteString(fmt.Sprintf("3:%d:%d,", id, accessHash))
}
c.Unlock()

if _, err := file.WriteString(buffer.String()); err != nil {
c.logger.Error("error writing to cache file: ", err)
}
}

func (c *CACHE) ReadFile() {
// Lock only for cache modification
c.Lock()
defer c.Unlock()

file, err := os.Open(c.fileN)
if err != nil && !os.IsNotExist(err) {
c.logger.Error("error opening cache file: ", err)
Expand All @@ -116,7 +110,6 @@ func (c *CACHE) ReadFile() {
buffer := make([]byte, 1)
var data []byte
totalLoaded := 0

for {
_, err := file.Read(buffer)
if err != nil {
Expand Down Expand Up @@ -160,6 +153,8 @@ func (c *CACHE) processData(data []byte) bool {
}

// process data
c.Lock()
defer c.Unlock()
switch splitData[0] {
case "1":
c.InputPeers.InputUsers[int64(id)] = int64(accessHash)
Expand All @@ -175,6 +170,9 @@ func (c *CACHE) processData(data []byte) bool {
}

func (c *CACHE) getUserPeer(userID int64) (InputUser, error) {
c.RLock()
defer c.RUnlock()

if userHash, ok := c.InputPeers.InputUsers[userID]; ok {
return &InputUserObj{UserID: userID, AccessHash: userHash}, nil
}
Expand All @@ -183,6 +181,9 @@ func (c *CACHE) getUserPeer(userID int64) (InputUser, error) {
}

func (c *CACHE) getChannelPeer(channelID int64) (InputChannel, error) {
c.RLock()
defer c.RUnlock()

if channelHash, ok := c.InputPeers.InputChannels[channelID]; ok {
return &InputChannelObj{ChannelID: channelID, AccessHash: channelHash}, nil
}
Expand Down Expand Up @@ -224,11 +225,11 @@ func (c *CACHE) GetInputPeer(peerID int64) (InputPeer, error) {

func (c *Client) getUserFromCache(userID int64) (*UserObj, error) {
c.Cache.RLock()
defer c.Cache.RUnlock()

if user, found := c.Cache.users[userID]; found {
c.Cache.RUnlock()
return user, nil
}
c.Cache.RUnlock()

userPeer, err := c.Cache.getUserPeer(userID)
if err != nil {
Expand All @@ -254,11 +255,11 @@ func (c *Client) getUserFromCache(userID int64) (*UserObj, error) {

func (c *Client) getChannelFromCache(channelID int64) (*Channel, error) {
c.Cache.RLock()
defer c.Cache.RUnlock()

if channel, found := c.Cache.channels[channelID]; found {
c.Cache.RUnlock()
return channel, nil
}
c.Cache.RUnlock()

channelPeer, err := c.Cache.getChannelPeer(channelID)
if err != nil {
Expand Down Expand Up @@ -289,11 +290,11 @@ func (c *Client) getChannelFromCache(channelID int64) (*Channel, error) {

func (c *Client) getChatFromCache(chatID int64) (*ChatObj, error) {
c.Cache.RLock()
defer c.Cache.RUnlock()

if chat, found := c.Cache.chats[chatID]; found {
c.Cache.RUnlock()
return chat, nil
}
c.Cache.RUnlock()

chat, err := c.MessagesGetChats([]int64{chatID})
if err != nil {
Expand Down Expand Up @@ -447,13 +448,18 @@ func (cache *CACHE) UpdatePeersToCache(users []User, chats []Chat) {
}

func (c *Client) GetPeerUser(userID int64) (*InputPeerUser, error) {
c.Cache.RLock()
defer c.Cache.RUnlock()

if peer, ok := c.Cache.InputPeers.InputUsers[userID]; ok {
return &InputPeerUser{UserID: userID, AccessHash: peer}, nil
}
return nil, fmt.Errorf("no user with id %d or missing from cache", userID)
}

func (c *Client) GetPeerChannel(channelID int64) (*InputPeerChannel, error) {
c.Cache.RLock()
defer c.Cache.RUnlock()

if peer, ok := c.Cache.InputPeers.InputChannels[channelID]; ok {
return &InputPeerChannel{ChannelID: channelID, AccessHash: peer}, nil
Expand All @@ -462,6 +468,9 @@ func (c *Client) GetPeerChannel(channelID int64) (*InputPeerChannel, error) {
}

func (c *Client) IdInCache(id int64) bool {
c.Cache.RLock()
defer c.Cache.RUnlock()

_, ok := c.Cache.InputPeers.InputUsers[id]
if ok {
return true
Expand Down
3 changes: 2 additions & 1 deletion telegram/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ func (c *Client) CreateExportedSender(dcID int) (*Client, error) {
if err != nil {
return nil, errors.Wrap(err, "exporting new sender")
}
exportedSender := &Client{MTProto: exported, Cache: c.Cache, Log: utils.NewLogger("gogram - sender").SetLevel(c.Log.Lev()), wg: sync.WaitGroup{}, clientData: c.clientData, stopCh: make(chan struct{})}

exportedSender := &Client{MTProto: exported, Cache: NewCache(LogDisable, ""), Log: utils.NewLogger("gogram - sender").SetLevel(c.Log.Lev()), wg: sync.WaitGroup{}, clientData: c.clientData, stopCh: make(chan struct{})}

initialReq := &InitConnectionParams{
ApiID: c.clientData.appID,
Expand Down
54 changes: 32 additions & 22 deletions telegram/media.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,13 @@ func countWorkers(parts int64) int {
} else if parts <= 50 {
return 3
} else if parts <= 100 {
return 5
} else if parts <= 200 {
return 6
} else if parts <= 400 {
} else if parts <= 200 {
return 7
} else if parts <= 500 {
} else if parts <= 400 {
return 8
} else if parts <= 500 {
return 10
} else {
return 12 // not recommended to use more than 15 workers
}
Expand Down Expand Up @@ -392,7 +392,6 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri

totalParts := parts

wg := sync.WaitGroup{}
numWorkers := countWorkers(parts)
if opts.Threads > 0 {
numWorkers = opts.Threads
Expand Down Expand Up @@ -440,20 +439,36 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri
}
}()

mu := sync.Mutex{}
wg := sync.WaitGroup{}

for p := int64(0); p < parts; p++ {
wg.Add(1)
for {
found := false
for i := 0; i < numWorkers; i++ {
if !w[i].buzy && w[i].c != nil {
go func(p int64) {
defer wg.Done()

for {
mu.Lock()
found := false
var workerIndex int
for i := 0; i < numWorkers; i++ {
if !w[i].buzy && w[i].c != nil {
found = true
w[i].buzy = true
workerIndex = i
break
}
}
mu.Unlock()

found = true
w[i].buzy = true
if found {
go func(i int, p int) {
defer func() {
mu.Lock()
w[i].buzy = false
wg.Done()
mu.Unlock()
}()

retryCount := 0
reqTimeout := 3 * time.Second

Expand Down Expand Up @@ -499,32 +514,27 @@ func (c *Client) DownloadMedia(file interface{}, Opts ...*DownloadOptions) (stri
if opts.ProgressCallback != nil {
go opts.ProgressCallback(int32(totalParts), int32(p))
}
w[i].buzy = false
case err := <-errorChan:
if handleIfFlood(err, c) {
goto partDownloadStartPoint
}
c.Logger.Error(err)
w[i].buzy = false
case <-time.After(reqTimeout):
c.Logger.Debug(fmt.Errorf("upload part %d timed out - retrying", p))
retryCount++
if retryCount > 5 {
if retryCount > 3 {
c.Logger.Debug(fmt.Errorf("upload part %d timed out - giving up", p))
return
} else if retryCount > 3 {
} else if retryCount > 2 {
reqTimeout = 5 * time.Second
}
goto partDownloadStartPoint
}
}(i, int(p))
}(workerIndex, int(p))
break
}
}

if found {
break
}
}
}(p)
}

wg.Wait()
Expand Down

0 comments on commit 74c4115

Please sign in to comment.