diff --git a/conf/lalserver.conf.json b/conf/lalserver.conf.json index 208aa3a5..d0cc09fa 100644 --- a/conf/lalserver.conf.json +++ b/conf/lalserver.conf.json @@ -60,7 +60,9 @@ "auth_enable": false, "auth_method": 1, "username": "q191201771", - "password": "pengrl" + "password": "pengrl", + "ws_rtsp_enable": true, + "ws_rtsp_addr": ":5566" }, "record": { "enable_flv": false, diff --git a/pkg/base/basic_http_sub_session.go b/pkg/base/basic_http_sub_session.go index b3f3d7b6..338bb9a5 100644 --- a/pkg/base/basic_http_sub_session.go +++ b/pkg/base/basic_http_sub_session.go @@ -59,7 +59,7 @@ func (session *BasicHttpSubSession) Dispose() error { func (session *BasicHttpSubSession) WriteHttpResponseHeader(b []byte) { if session.IsWebSocket { - session.write(UpdateWebSocketHeader(session.WebSocketKey)) + session.write(UpdateWebSocketHeader(session.WebSocketKey, "")) } else { session.write(b) } diff --git a/pkg/base/websocket.go b/pkg/base/websocket.go index 9ffaafbf..b3c4cc2f 100644 --- a/pkg/base/websocket.go +++ b/pkg/base/websocket.go @@ -9,8 +9,12 @@ package base import ( + "bufio" "crypto/sha1" "encoding/base64" + "encoding/binary" + "fmt" + "io" "math" "github.com/q191201771/naza/pkg/bele" @@ -140,17 +144,151 @@ func MakeWsFrameHeader(wsHeader WsHeader) (buf []byte) { } return buf } -func UpdateWebSocketHeader(secWebSocketKey string) []byte { + +func UpdateWebSocketHeader(secWebSocketKey, protocol string) []byte { firstLine := "HTTP/1.1 101 Switching Protocol\r\n" sha1Sum := sha1.Sum([]byte(secWebSocketKey + WsMagicStr)) secWebSocketAccept := base64.StdEncoding.EncodeToString(sha1Sum[:]) - webSocketResponseHeaderStr := firstLine + - "Server: " + LalHttpflvSubSessionServer + "\r\n" + - "Sec-WebSocket-Accept:" + secWebSocketAccept + "\r\n" + - "Keep-Alive: timeout=15, max=100\r\n" + - "Connection: Upgrade\r\n" + - "Upgrade: websocket\r\n" + - CorsHeaders + - "\r\n" + + var webSocketResponseHeaderStr string + if protocol == "" { + webSocketResponseHeaderStr = firstLine + + "Server: " + LalHttpflvSubSessionServer + "\r\n" + + "Sec-WebSocket-Accept:" + secWebSocketAccept + "\r\n" + + "Keep-Alive: timeout=15, max=100\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + CorsHeaders + + "\r\n" + } else { + webSocketResponseHeaderStr = firstLine + + "Server: " + LalHttpflvSubSessionServer + "\r\n" + + "Sec-WebSocket-Accept:" + secWebSocketAccept + "\r\n" + + "Keep-Alive: timeout=15, max=100\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + CorsHeaders + + "Sec-WebSocket-Protocol:" + protocol + "\r\n" + + "\r\n" + } return []byte(webSocketResponseHeaderStr) } + +func ReadWsPayload(r *bufio.Reader) ([]byte, error) { + var h WsHeader + + buf := make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + + h.Fin = (buf[0] & 0x80) != 0 + h.Rsv1 = (buf[0] & 0x40) != 0 + h.Rsv2 = (buf[0] & 0x20) != 0 + h.Rsv3 = (buf[0] & 0x10) != 0 + h.Opcode = buf[0] & 0x0f + + if buf[1]&0x80 != 0 { + h.Masked = true + } + + length := buf[1] & 0x7f + switch { + case length < 126: + h.PayloadLength = uint64(length) + case length == 126: + buf = make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + + h.PayloadLength = uint64(binary.BigEndian.Uint16(buf)) + case length == 127: + buf = make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + + h.PayloadLength = binary.BigEndian.Uint64(buf) + + default: + err = fmt.Errorf("header error: the most significant bit must be 0") + return nil, err + } + + if h.Masked { + buf = make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + + h.MaskKey = bele.BeUint32(buf) + } + + payload := make([]byte, h.PayloadLength) + _, err = io.ReadFull(r, payload) + if err != nil { + return nil, err + } + + if h.Masked { + mask := make([]byte, 4) + binary.BigEndian.PutUint32(mask, h.MaskKey) + cipher(payload, mask, 0) + } + + return payload, nil +} + +func cipher(payload []byte, mask []byte, offset int) { + n := len(payload) + if n < 8 { + for i := 0; i < n; i++ { + payload[i] ^= mask[(offset+i)%4] + } + return + } + + // Calculate position in mask due to previously processed bytes number. + mpos := offset % 4 + // Count number of bytes will processed one by one from the beginning of payload. + ln := remain[mpos] + // Count number of bytes will processed one by one from the end of payload. + // This is done to process payload by 8 bytes in each iteration of main loop. + rn := (n - ln) % 8 + + for i := 0; i < ln; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + for i := n - rn; i < n; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + + // NOTE: we use here binary.LittleEndian regardless of what is real + // endianness on machine is. To do so, we have to use binary.LittleEndian in + // the masking loop below as well. + var ( + m = binary.LittleEndian.Uint32((mask[:])) + m2 = uint64(m)<<32 | uint64(m) + ) + // Skip already processed right part. + // Get number of uint64 parts remaining to process. + n = (n - ln - rn) >> 3 + for i := 0; i < n; i++ { + var ( + j = ln + (i << 3) + chunk = payload[j : j+8] + ) + p := binary.LittleEndian.Uint64(chunk) + p = p ^ m2 + binary.LittleEndian.PutUint64(chunk, p) + } +} + +// remain maps position in masking key [0,4) to number +// of bytes that need to be processed manually inside Cipher(). +var remain = [4]int{0, 3, 2, 1} diff --git a/pkg/logic/config.go b/pkg/logic/config.go index 4c6d61da..67b5f093 100644 --- a/pkg/logic/config.go +++ b/pkg/logic/config.go @@ -102,6 +102,8 @@ type RtspConfig struct { RtspsCertFile string `json:"rtsps_cert_file"` RtspsKeyFile string `json:"rtsps_key_file"` OutWaitKeyFrameFlag bool `json:"out_wait_key_frame_flag"` + WsRtspEnable bool `json:"ws_rtsp_enable"` + WsRtspAddr string `json:"ws_rtsp_addr"` rtsp.ServerAuthConfig } diff --git a/pkg/logic/server_manager__.go b/pkg/logic/server_manager__.go index e42dac5e..9dc6e994 100644 --- a/pkg/logic/server_manager__.go +++ b/pkg/logic/server_manager__.go @@ -11,7 +11,6 @@ package logic import ( "flag" "fmt" - "github.com/q191201771/naza/pkg/taskpool" "net/http" _ "net/http/pprof" "os" @@ -19,6 +18,8 @@ import ( "sync" "time" + "github.com/q191201771/naza/pkg/taskpool" + "github.com/q191201771/lal/pkg/base" "github.com/q191201771/lal/pkg/hls" "github.com/q191201771/lal/pkg/httpflv" @@ -45,6 +46,7 @@ type ServerManager struct { rtspsServer *rtsp.Server httpApiServer *HttpApiServer pprofServer *http.Server + wsrtspServer *rtsp.WebsocketServer exitChan chan struct{} mutex sync.Mutex @@ -139,6 +141,9 @@ Doc: %s if sm.config.RtspConfig.RtspsEnable { sm.rtspsServer = rtsp.NewServer(sm.config.RtspConfig.RtspsAddr, sm, sm.config.RtspConfig.ServerAuthConfig) } + if sm.config.RtspConfig.WsRtspEnable { + sm.wsrtspServer = rtsp.NewWebsocketServer(sm.config.RtspConfig.WsRtspAddr, sm, sm.config.RtspConfig.ServerAuthConfig) + } if sm.config.HttpApiConfig.Enable { sm.httpApiServer = NewHttpApiServer(sm.config.HttpApiConfig.Addr, sm) } @@ -268,6 +273,15 @@ func (sm *ServerManager) RunLoop() error { } } + if sm.wsrtspServer != nil { + go func() { + err := sm.wsrtspServer.Listen() + if err != nil { + Log.Error(err) + } + }() + } + if sm.httpApiServer != nil { if err := sm.httpApiServer.Listen(); err != nil { return err diff --git a/pkg/rtsp/server.go b/pkg/rtsp/server.go index a235fb94..3318a78f 100644 --- a/pkg/rtsp/server.go +++ b/pkg/rtsp/server.go @@ -140,7 +140,7 @@ func (s *Server) OnDelRtspSubSession(session *SubSession) { // --------------------------------------------------------------------------------------------------------------------- func (s *Server) handleTcpConnect(conn net.Conn) { - session := NewServerCommandSession(s, conn, s.auth) + session := NewServerCommandSession(s, conn, s.auth, false, "") s.observer.OnNewRtspSessionConnect(session) err := session.RunLoop() diff --git a/pkg/rtsp/server_command_session.go b/pkg/rtsp/server_command_session.go index e10316fd..8125a0e4 100644 --- a/pkg/rtsp/server_command_session.go +++ b/pkg/rtsp/server_command_session.go @@ -10,6 +10,7 @@ package rtsp import ( "bufio" + "bytes" "fmt" "net" "strings" @@ -66,10 +67,12 @@ type ServerCommandSession struct { pubSession *PubSession subSession *SubSession - describeSeq string // only for sub session + describeSeq string // only for sub session + isWebSocket bool + websocketKey string } -func NewServerCommandSession(observer IServerCommandSessionObserver, conn net.Conn, authConf ServerAuthConfig) *ServerCommandSession { +func NewServerCommandSession(observer IServerCommandSessionObserver, conn net.Conn, authConf ServerAuthConfig, iswebsocket bool, websocketKey string) *ServerCommandSession { uk := base.GenUkRtspServerCommandSession() s := &ServerCommandSession{ uniqueKey: uk, @@ -79,9 +82,11 @@ func NewServerCommandSession(observer IServerCommandSessionObserver, conn net.Co option.ReadBufSize = serverCommandSessionReadBufSize option.WriteChanSize = serverCommandSessionWriteChanSize }), + isWebSocket: iswebsocket, + websocketKey: websocketKey, } - Log.Infof("[%s] lifecycle new rtsp ServerSession. session=%p, laddr=%s, raddr=%s", uk, s, conn.LocalAddr().String(), conn.RemoteAddr().String()) + Log.Infof("[%s] lifecycle new rtsp ServerSession. session=%p, laddr=%s, raddr=%s, iswebsocket:%v", uk, s, conn.LocalAddr().String(), conn.RemoteAddr().String(), iswebsocket) return s } @@ -102,6 +107,10 @@ func (session *ServerCommandSession) FeedSdp(b []byte) { // // 使用RTSP TCP命令连接,向对端发送RTP数据 func (session *ServerCommandSession) WriteInterleavedPacket(packet []byte, channel int) error { + if session.isWebSocket { + respLen := len(packInterleaved(channel, packet)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write(packInterleaved(channel, packet)) return err } @@ -157,28 +166,48 @@ func (session *ServerCommandSession) runCmdLoop() error { Loop: for { - isInterleaved, packet, channel, err := readInterleaved(r) - if err != nil { - Log.Errorf("[%s] read interleaved error. err=%+v", session.uniqueKey, err) - break Loop - } - if isInterleaved { - if session.pubSession != nil { - session.pubSession.HandleInterleavedPacket(packet, int(channel)) - } else if session.subSession != nil { - session.subSession.HandleInterleavedPacket(packet, int(channel)) - } else { - Log.Errorf("[%s] read interleaved packet but pub or sub not exist.", session.uniqueKey) + var requestCtx nazahttp.HttpReqMsgCtx + if session.isWebSocket { + // 解析出websocket的body信息 + payload, err := base.ReadWsPayload(r) + if err != nil { + Log.Errorf("[%s] read ws payload error. err=%+v", session.uniqueKey, err) break Loop } - continue - } - // 读取一个message - requestCtx, err := nazahttp.ReadHttpRequestMessage(r) - if err != nil { - Log.Errorf("[%s] read rtsp message error. err=%+v", session.uniqueKey, err) - break Loop + // 读取一个message + reader := bytes.NewReader(payload) + rr := bufio.NewReader(reader) + requestCtx, err = nazahttp.ReadHttpRequestMessage(rr) + if err != nil { + Log.Errorf("[%s] read rtsp message error. err=%+v", session.uniqueKey, err) + break Loop + } + } else { + isInterleaved, packet, channel, err := readInterleaved(r) + if err != nil { + Log.Errorf("[%s] read interleaved error. err=%+v", session.uniqueKey, err) + break Loop + } + + if isInterleaved { + if session.pubSession != nil { + session.pubSession.HandleInterleavedPacket(packet, int(channel)) + } else if session.subSession != nil { + session.subSession.HandleInterleavedPacket(packet, int(channel)) + } else { + Log.Errorf("[%s] read interleaved packet but pub or sub not exist.", session.uniqueKey) + break Loop + } + continue + } + + // 读取一个message + requestCtx, err = nazahttp.ReadHttpRequestMessage(r) + if err != nil { + Log.Errorf("[%s] read rtsp message error. err=%+v", session.uniqueKey, err) + break Loop + } } Log.Debugf("[%s] read http request. method=%s, uri=%s, version=%s, headers=%+v, body=%s", @@ -226,6 +255,10 @@ Loop: func (session *ServerCommandSession) handleOptions(requestCtx nazahttp.HttpReqMsgCtx) error { Log.Infof("[%s] < R OPTIONS", session.uniqueKey) resp := PackResponseOptions(requestCtx.Headers.Get(HeaderCSeq)) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write([]byte(resp)) return err } @@ -269,6 +302,10 @@ func (session *ServerCommandSession) handleDescribe(requestCtx nazahttp.HttpReqM } if authresp != "" { + if session.isWebSocket { + respLen := len([]byte(authresp)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write([]byte(authresp)) return err } @@ -301,6 +338,10 @@ func (session *ServerCommandSession) feedSdp(rawSdp []byte) error { session.subSession.InitWithSdp(sdpCtx) resp := PackResponseDescribe(session.describeSeq, string(rawSdp)) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write([]byte(resp)) return err } @@ -371,6 +412,10 @@ func (session *ServerCommandSession) handleSetup(requestCtx nazahttp.HttpReqMsgC } resp := PackResponseSetup(requestCtx.Headers.Get(HeaderCSeq), htv) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err = session.conn.Write([]byte(resp)) return err } @@ -406,6 +451,10 @@ func (session *ServerCommandSession) handleSetup(requestCtx nazahttp.HttpReqMsgC } resp := PackResponseSetup(requestCtx.Headers.Get(HeaderCSeq), htv) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err = session.conn.Write([]byte(resp)) return err } @@ -433,6 +482,10 @@ func (session *ServerCommandSession) handlePlay(requestCtx nazahttp.HttpReqMsgCt return err } resp := PackResponsePlay(requestCtx.Headers.Get(HeaderCSeq)) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write([]byte(resp)) return err } @@ -440,6 +493,23 @@ func (session *ServerCommandSession) handlePlay(requestCtx nazahttp.HttpReqMsgCt func (session *ServerCommandSession) handleTeardown(requestCtx nazahttp.HttpReqMsgCtx) error { Log.Infof("[%s] < R TEARDOWN", session.uniqueKey) resp := PackResponseTeardown(requestCtx.Headers.Get(HeaderCSeq)) + if session.isWebSocket { + respLen := len([]byte(resp)) + session.writeWsFrameHeader(respLen) + } _, err := session.conn.Write([]byte(resp)) return err } + +func (session *ServerCommandSession) writeWsFrameHeader(respLen int) { + wsHeader := base.WsHeader{ + Fin: true, + Rsv1: false, + Rsv2: false, + Rsv3: false, + Opcode: base.Wso_Binary, + PayloadLength: uint64(respLen), + Masked: false, + } + session.conn.Write(base.MakeWsFrameHeader(wsHeader)) +} diff --git a/pkg/rtsp/websocket_server.go b/pkg/rtsp/websocket_server.go new file mode 100644 index 00000000..2874e6e2 --- /dev/null +++ b/pkg/rtsp/websocket_server.go @@ -0,0 +1,110 @@ +package rtsp + +import ( + "net" + "net/http" + "strings" + + "github.com/q191201771/lal/pkg/base" +) + +type WebsocketServer struct { + addr string + observer IServerObserver + + ln net.Listener + auth ServerAuthConfig + httpServer http.Server +} + +func NewWebsocketServer(addr string, observer IServerObserver, auth ServerAuthConfig) *WebsocketServer { + return &WebsocketServer{ + addr: addr, + observer: observer, + auth: auth, + } +} + +func (s *WebsocketServer) Listen() (err error) { + s.ln, err = net.Listen("tcp", s.addr) + if err != nil { + return + } + Log.Infof("start ws rtsp server listen. addr=%s", s.addr) + + server := http.Server{ + Handler: http.HandlerFunc(s.HandleWebsocket), + } + server.Serve(s.ln) + return +} + +func (s *WebsocketServer) HandleWebsocket(w http.ResponseWriter, r *http.Request) { + conn, bio, err := w.(http.Hijacker).Hijack() + if err != nil { + Log.Errorf("hijack failed. err=%+v", err) + return + } + if bio.Reader.Buffered() != 0 || bio.Writer.Buffered() != 0 { + Log.Errorf("hijack but buffer not empty. rb=%d, wb=%d", bio.Reader.Buffered(), bio.Writer.Buffered()) + } + + var ( + isWebSocket bool + webSocketKey string + ) + // 火狐浏览器 Connection = [keep-alive, Upgrade] + if strings.Contains(r.Header.Get("Connection"), "Upgrade") && r.Header.Get("Upgrade") == "websocket" { + isWebSocket = true + webSocketKey = r.Header.Get("Sec-WebSocket-Key") + } + + session := NewServerCommandSession(s, conn, s.auth, isWebSocket, webSocketKey) + s.observer.OnNewRtspSessionConnect(session) + + session.conn.Write(base.UpdateWebSocketHeader(webSocketKey, "rtsp")) + + err = session.RunLoop() + Log.Info(err) + + if session.pubSession != nil { + s.observer.OnDelRtspPubSession(session.pubSession) + _ = session.pubSession.Dispose() + } else if session.subSession != nil { + s.observer.OnDelRtspSubSession(session.subSession) + _ = session.subSession.Dispose() + } + s.observer.OnDelRtspSession(session) + +} + +func (s *WebsocketServer) Dispose() { + if s.ln == nil { + return + } + if err := s.ln.Close(); err != nil { + Log.Error(err) + } +} + +// ----- ServerCommandSessionObserver ---------------------------------------------------------------------------------- + +func (s *WebsocketServer) OnNewRtspPubSession(session *PubSession) error { + return s.observer.OnNewRtspPubSession(session) +} + +func (s *WebsocketServer) OnNewRtspSubSessionDescribe(session *SubSession) (ok bool, sdp []byte) { + return s.observer.OnNewRtspSubSessionDescribe(session) +} + +func (s *WebsocketServer) OnNewRtspSubSessionPlay(session *SubSession) error { + return s.observer.OnNewRtspSubSessionPlay(session) +} + +func (s *WebsocketServer) OnDelRtspPubSession(session *PubSession) { + s.observer.OnDelRtspPubSession(session) +} + +func (s *WebsocketServer) OnDelRtspSubSession(session *SubSession) { + s.observer.OnDelRtspSubSession(session) +}