-
Notifications
You must be signed in to change notification settings - Fork 0
/
communicationhandler.go
320 lines (265 loc) · 7.56 KB
/
communicationhandler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
// Package o3 provides the central communication unit responsible for complete exchanges (like
// the handshake and subsequent message reception). It uses functions in packethandler and
// packetdispatcher to deal with incoming and outgoing messages. Errors in underlying functions
// bubble up as panics and have to be recovered here, converted to Go errors and returned.
package o3
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
)
// receiveHelper reads exactly n bytes from reader and returns them in a new
// buffer. It panics if n bytes cannot be read; callers are expected to recover
// and convert the panic into an error (as Run does).
func receiveHelper(reader io.Reader, n int) *bytes.Buffer {
	buf := make([]byte, n)
	// A single Read on a stream (e.g. a TCP connection) may legitimately return
	// fewer bytes than requested; io.ReadFull keeps reading until buf is full
	// instead of silently returning a zero-padded buffer.
	if _, err := io.ReadFull(reader, buf); err != nil {
		panic(err)
	}
	return bytes.NewBuffer(buf)
}
// ReceivedMsg is a type used to transmit messages via a channel: each value
// delivered on the receive channel returned by Run describes either one
// received message or one failure while receiving.
type ReceivedMsg struct {
	// Msg is the decoded message. It is nil when receiving the packet failed.
	Msg Message
	// Err is the receive or decode error, if any.
	Err error
}
// preflightCheck verifies that the session has credentials to connect with.
// It returns an error if the ID consists only of zero bytes, or else if the
// secret key (LSK) consists only of zero bytes; otherwise it returns nil.
func (sc *SessionContext) preflightCheck() error {
	idSet, keySet := false, false

	for _, v := range sc.ID.ID {
		if v != 0x0 {
			idSet = true
			break
		}
	}
	if !idSet {
		return errors.New("cannot connect using empty ID")
	}

	for _, v := range sc.ID.LSK {
		if v != 0x0 {
			keySet = true
			break
		}
	}
	if !keySet {
		return errors.New("cannot connect using empty secret key")
	}

	return nil
}
// Run connects to g-33.0.threema.ch:5222, performs the handshake and starts
// the background receive loop, which in turn starts the send loop once the
// server signals that the connection is established.
//
// It returns a send-only channel on which outgoing messages are enqueued and
// a receive-only channel delivering received messages and receive errors.
// A non-nil error is returned (with nil channels) if the session has no
// ID/secret key, if dialing fails, or if the handshake fails.
func (sc *SessionContext) Run() (sendChan chan<- Message, recvChan <-chan ReceivedMsg, err error) {
	defer func() {
		if r := recover(); r != nil {
			// The handshake helpers report failures by panicking. Surface the
			// failure as an error instead of returning nil channels with a nil
			// error, and don't leak the half-open connection.
			err = handlerPanicHandler("Receive Messages", r)
			sendChan, recvChan = nil, nil
			if sc.connection != nil {
				sc.connection.Close()
			}
		}
	}()

	// Check that we have an ID and LSK to work with.
	if err = sc.preflightCheck(); err != nil {
		return nil, nil, err
	}

	sc.connection, err = net.Dial("tcp", "g-33.0.threema.ch:5222")
	if err != nil {
		return nil, nil, err
	}

	// Handshake; any failure panics and is converted by the deferred recover.
	sc.dispatchClientHello(sc.connection)
	sc.handleServerHello(receiveHelper(sc.connection, 80))
	sc.dispatchAuthMsg(sc.connection)
	sc.handleHandshakeAck(receiveHelper(sc.connection, 32))

	// TODO: find better way to handle large amounts of offline messages
	sc.sendMsgChan = make(chan Message, 1000)
	sc.receiveMsgChan = make(chan ReceivedMsg, 1000)

	// receiveLoop starts sendLoop when the server confirms the connection.
	go sc.receiveLoop()
	return sc.sendMsgChan, sc.receiveMsgChan, nil
}
// receiveLoop reads packets from sc.connection until the stream ends or the
// connection fails. It acknowledges and decodes message packets, forwards
// results and errors on sc.receiveMsgChan, records echo counters and starts
// sendLoop when the server signals that the connection is established.
// The connection is closed when the loop exits.
func (sc *SessionContext) receiveLoop() {
	defer sc.connection.Close()
	for {
		pktIntf, err := sc.receivePacket(sc.connection)
		if err != nil {
			if errors.Is(err, io.EOF) {
				// The peer closed the stream cleanly.
				return
			}
			sc.receiveMsgChan <- ReceivedMsg{
				Msg: nil,
				Err: err,
			}
			// A truncated stream or a non-timeout network error (connection
			// reset, use of a closed connection, ...) fails again on every
			// further read. Retrying would spin forever and eventually block
			// on the full channel, so stop here. Other errors (e.g. a packet
			// that could not be decoded) leave the stream usable.
			var netErr net.Error
			if errors.Is(err, io.ErrUnexpectedEOF) || (errors.As(err, &netErr) && !netErr.Timeout()) {
				return
			}
			continue
		}

		switch pkt := pktIntf.(type) {
		case messagePacket:
			// Acknowledge the message packet, then decode the actual message.
			sc.dispatchAckMsg(sc.connection, pkt)
			var rmsg ReceivedMsg
			rmsg.Msg, rmsg.Err = sc.handleMessagePacket(pkt)
			sc.receiveMsgChan <- rmsg
		case ackPacket:
			// Nothing to do.
		case echoPacket:
			sc.echoCounter = pkt.Counter
		case connEstPacket:
			go sc.sendLoop()
		default:
			fmt.Printf("ReceiveMessages: unhandled packet type: %T", pkt)
			// Tell the consumer why the session stopped instead of going quiet.
			sc.receiveMsgChan <- ReceivedMsg{
				Msg: nil,
				Err: fmt.Errorf("receive loop stopped: unhandled packet type %T", pkt),
			}
			return
		}
		// TODO: Implement the echo ping pong
	}
}
// sendLoop dispatches every message enqueued on sc.sendMsgChan over
// sc.connection. Every 3 minutes it also sends an echo packet carrying the
// current sc.echoCounter. It returns once sc.sendMsgChan is closed.
//
// NOTE(review): sc.echoCounter is written by receiveLoop on another goroutine
// without synchronization, so reading it here is a data race; confirm and guard it.
func (sc *SessionContext) sendLoop() {
	// Use a ticker owned by this loop instead of time.Tick plus a helper
	// goroutine: both of those leaked, since nothing could ever stop them.
	ticker := time.NewTicker(3 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case msg, ok := <-sc.sendMsgChan:
			if !ok {
				// The channel was closed; without this check the loop would
				// dispatch nil messages forever.
				return
			}
			sc.dispatchMessage(sc.connection, msg)
		case <-ticker.C:
			sc.dispatchEchoMsg(sc.connection, echoPacket{
				PktType: echoMsg,
				Counter: sc.echoCounter,
			})
		}
	}
}
// SendTextMessage builds a text message for recipient and enqueues it on
// sendMsgChan (typically the channel returned by Run), blocking while the
// channel is full. If building the message fails, the error is returned and
// nothing is enqueued.
func (sc *SessionContext) SendTextMessage(recipient string, text string, sendMsgChan chan<- Message) error {
	msg, err := NewTextMessage(sc, recipient, text)
	if err == nil {
		sendMsgChan <- msg
	}
	return err
}
// SendImageMessage builds an image message from the file at filename for
// recipient and enqueues it on sendMsgChan (typically the channel returned by
// Run), blocking while the channel is full. If building the message fails,
// the error is returned and nothing is enqueued.
func (sc *SessionContext) SendImageMessage(recipient string, filename string, sendMsgChan chan<- Message) error {
	msg, err := NewImageMessage(sc, recipient, filename)
	if err == nil {
		sendMsgChan <- msg
	}
	return err
}
// SendAudioMessage builds an audio message from the file at filename for
// recipient and enqueues it on sendMsgChan (typically the channel returned by
// Run), blocking while the channel is full. If building the message fails,
// the error is returned and nothing is enqueued.
// Works with various audio formats: Threema uses some kind of mp4, but mp3 works fine.
func (sc *SessionContext) SendAudioMessage(recipient string, filename string, sendMsgChan chan<- Message) error {
	msg, err := NewAudioMessage(sc, recipient, filename)
	if err == nil {
		sendMsgChan <- msg
	}
	return err
}
// SendGroupTextMessage builds the text messages for all members of group and
// enqueues each of them on sendMsgChan. If building the messages fails, the
// error is returned and nothing is enqueued.
func (sc *SessionContext) SendGroupTextMessage(group Group, text string, sendMsgChan chan<- Message) (err error) {
	msgs, err := NewGroupTextMessages(sc, group, text)
	if err == nil {
		for i := range msgs {
			sendMsgChan <- msgs[i]
		}
	}
	return err
}
// CreateNewGroup assigns a freshly generated group ID to group, notifies all
// members of the member list and of the group name, and returns the new ID.
// If either notification fails, the zero ID and the error are returned.
func (sc *SessionContext) CreateNewGroup(group Group, sendMsgChan chan<- Message) (groupID [8]byte, err error) {
	group.GroupID = NewGrpID()
	// Previously these return values were dropped, so the err checks below
	// could never fire.
	if err = sc.ChangeGroupMembers(group, sendMsgChan); err != nil {
		return groupID, err
	}
	if err = sc.RenameGroup(group, sendMsgChan); err != nil {
		return groupID, err
	}
	// Return the ID that was actually assigned; the named result was never set
	// before, so callers always received the zero ID.
	return group.GroupID, nil
}
// RenameGroup enqueues on sendMsgChan the set-name messages built for group,
// announcing the group's (new) name to its members. It always returns nil.
func (sc *SessionContext) RenameGroup(group Group, sendMsgChan chan<- Message) (err error) {
	msgs := NewGroupManageSetNameMessages(sc, group)
	for i := range msgs {
		sendMsgChan <- msgs[i]
	}
	return nil
}
// ChangeGroupMembers enqueues on sendMsgChan the set-members messages built
// for group, announcing the group's (new) member list to its members.
// It always returns nil.
func (sc *SessionContext) ChangeGroupMembers(group Group, sendMsgChan chan<- Message) (err error) {
	msgs := NewGroupManageSetMembersMessages(sc, group)
	for i := range msgs {
		sendMsgChan <- msgs[i]
	}
	return nil
}
// LeaveGroup enqueues on sendMsgChan the member-left messages built for group,
// telling its members that the sender has left. It always returns nil.
func (sc *SessionContext) LeaveGroup(group Group, sendMsgChan chan<- Message) (err error) {
	msgs := NewGroupMemberLeftMessages(sc, group)
	for i := range msgs {
		sendMsgChan <- msgs[i]
	}
	return nil
}
// receivePacket reads one length-prefixed packet from reader and decodes it
// with handleClientServerMsg. A clean end of stream is returned as io.EOF
// (unwrapped) so callers can detect it. A stream that ends mid-payload yields
// a descriptive error. Panics raised while decoding are recovered and
// converted via handlerPanicHandler.
func (sc *SessionContext) receivePacket(reader io.Reader) (pkt interface{}, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = handlerPanicHandler("receivePacket", r)
		}
	}()
	length, err := receivePacketLength(reader)
	if err != nil {
		return nil, err
	}
	buf := make([]byte, length)
	// A single Read on a TCP stream often returns only part of a large
	// payload. Treating that as a bad packet would desynchronize the stream,
	// so keep reading until the whole payload has arrived.
	n, err := io.ReadFull(reader, buf)
	if err != nil {
		if err == io.ErrUnexpectedEOF {
			return nil, fmt.Errorf("packet of invalid length received. Expected: %d; received: %d: %w", length, n, err)
		}
		return nil, err
	}
	return sc.handleClientServerMsg(bytes.NewBuffer(buf)), nil
}
// receivePacketLength reads the 2-byte little-endian length prefix that
// precedes every packet. Read errors are returned unchanged (a clean end of
// stream before the first byte is io.EOF). A stream that ends after only one
// byte yields a descriptive error.
func receivePacketLength(reader io.Reader) (uint16, error) {
	var lbuf [2]byte
	// io.ReadFull tolerates the prefix arriving split across two reads, which
	// a single Read call did not.
	if _, err := io.ReadFull(reader, lbuf[:]); err != nil {
		if err == io.ErrUnexpectedEOF {
			return 0, fmt.Errorf("no parseable packet length received: %w", err)
		}
		return 0, err
	}
	return binary.LittleEndian.Uint16(lbuf[:]), nil
}