diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index e77a70b8..28c9d7a4 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -7,11 +7,6 @@ import ( "time" glist "github.com/bahlo/generic-list-go" - "github.com/pingcap/tiproxy/lib/util/errors" -) - -var ( - ErrNoInstanceToSelect = errors.New("no instances to route") ) // ConnEventReceiver receives connection events. diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index f5acb6e2..98996160 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -18,10 +18,6 @@ import ( "go.uber.org/zap" ) -var ( - ErrCapabilityNegotiation = errors.New("capability negotiation failed") -) - const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF @@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili // The error cannot be sent to the client because the client only expects an initial handshake packet. // The only way is to log it and disconnect. logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps)) - return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps) + return errors.Wrapf(ErrBackendCap, "require %s from backend", requiredBackendCaps^commonCaps) } if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) { - return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg) + return ErrBackendNoTLS } return nil } @@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt)) if isSSL { if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil { - return pnet.WrapUserError(err, err.Error()) + return errors.Wrap(ErrClientHandshake, err) } pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp() if err != nil { @@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil { return writeErr } - return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps) + return errors.Wrapf(ErrClientCap, "require %s from frontend", requiredFrontendCaps&^commonCaps) } commonCaps := frontendCapability & proxyCapability if frontendCapability^commonCaps != 0 { @@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if errors.As(err, &warning) { logger.Warn("parse handshake response encounters error", zap.Error(err)) } else if err != nil { - return pnet.WrapUserError(err, parsePktErrMsg) + return err } if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil { - return pnet.WrapUserError(err, err.Error()) + return errors.Wrap(ErrProxyErr, err) } auth.user = clientResp.User auth.dbname = clientResp.DB @@ -163,29 +159,28 @@ RECONNECT: // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. backendIO, err := getBackendIO(cctx, auth, clientResp) if err != nil { - return pnet.WrapUserError(err, connectErrMsg) + return err } backendIO.ResetSequence() // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return pnet.WrapUserError(err, handshakeErrMsg) + return err } // read backend initial handshake serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO) if err != nil { - if IsMySQLError(err) { + if pnet.IsMySQLError(err) { if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil { - err = writeErr + return writeErr } - return err } - return pnet.WrapUserError(err, handshakeErrMsg) + return err } if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { - return pnet.WrapUserError(err, capabilityErrMsg) + return err } if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 { @@ -207,7 +202,7 @@ RECONNECT: // Copy the auth data so that the backend can set correct `using password` in the error message. unknownAuthPlugin, clientResp.AuthData, 0, ); err != nil { - return pnet.WrapUserError(err, handshakeErrMsg) + return err } // forward other packets @@ -220,16 +215,18 @@ loop: // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { - return pnet.WrapUserError(err, checkPPV2ErrMsg) + return errors.Wrap(ErrBackendPPV2, err) } return err } - var packetErr error + var packetErr *mysql.MyError if serverPkt[0] == pnet.ErrHeader.Byte() { packetErr = pnet.ParseErrorPacket(serverPkt) - if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*mysql.MyError)) { - logger.Warn("handle handshake error, start reconnect", zap.Error(err)) - backendIO.Close() + if handshakeHandler.HandleHandshakeErr(cctx, packetErr) { + logger.Warn("handle handshake error, start reconnect", zap.Error(packetErr)) + if closeErr := backendIO.Close(); closeErr != nil { + logger.Warn("close backend error", zap.Error(closeErr)) + } goto RECONNECT } } @@ -238,17 +235,17 @@ loop: return err } if packetErr != nil { - return packetErr + return errors.Wrap(ErrClientAuthFail, packetErr) } pktIdx++ switch serverPkt[0] { case pnet.OKHeader.Byte(): if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil { - return err + return errors.Wrap(ErrClientHandshake, err) } if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil { - return err + return errors.Wrap(ErrBackendHandshake, err) } return nil default: // mysql.AuthSwitchRequest, ShaCommand @@ -276,7 +273,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) { func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error { if len(sessionToken) == 0 { - return errors.New("session token is empty") + return errors.Wrapf(ErrBackendHandshake, "session token is empty") } // write proxy header @@ -301,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac } if err = auth.handleSecondAuthResult(backendIO); err == nil { - return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel) + if err = setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil { + return errors.Wrap(ErrBackendHandshake, err) + } } - return err + return errors.Wrap(ErrBackendHandshake, err) } func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) { if serverPkt, err = backendIO.ReadPacket(); err != nil { + err = errors.Wrap(ErrBackendHandshake, err) return } if pnet.IsErrorPacket(serverPkt[0]) { - err = pnet.ParseErrorPacket(serverPkt) + err = errors.Wrap(ErrBackendHandshake, pnet.ParseErrorPacket(serverPkt)) return } capability, _, _ = pnet.ParseInitialHandshake(serverPkt) @@ -346,7 +346,7 @@ func (auth *Authenticator) writeAuthHandshake( var enableTLS bool if auth.requireBackendTLS { if backendTLSConfig == nil { - return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg) + return ErrProxyNoTLS } enableTLS = true } else { @@ -358,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake( pkt = pnet.MakeHandshakeResponse(resp) // write SSL Packet if err := backendIO.WritePacket(pkt[:32], true); err != nil { - return err + return errors.Wrap(ErrBackendHandshake, err) } // Send TLS / SSL request packet. The server must have supported TLS. tcfg := backendTLSConfig.Clone() @@ -370,7 +370,7 @@ func (auth *Authenticator) writeAuthHandshake( if err := backendIO.ClientTLSHandshake(tcfg); err != nil { // tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet // tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet - return pnet.WrapUserError(err, checkPPV2ErrMsg) + return errors.Wrap(ErrBackendPPV2, err) } } else { resp.Capability &= ^pnet.ClientSSL @@ -378,7 +378,10 @@ func (auth *Authenticator) writeAuthHandshake( } // write handshake resp - return backendIO.WritePacket(pkt, true) + if err := backendIO.WritePacket(pkt, true); err != nil { + return errors.Wrap(ErrBackendHandshake, err) + } + return nil } func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error { @@ -393,7 +396,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro case pnet.ErrHeader.Byte(): return pnet.ParseErrorPacket(data) default: // mysql.AuthSwitchRequest, ShaCommand: - return errors.Errorf("read unexpected command: %#x", data[0]) + return errors.Wrapf(mysql.ErrMalformPacket, "read unexpected command: %#x", data[0]) } } diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 47d3e99d..221068d0 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" "github.com/stretchr/testify/require" ) @@ -70,10 +69,14 @@ func TestUnsupportedCapability(t *testing.T) { for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { - if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { - require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) - } else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { - require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) + if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { + require.ErrorIs(t, ts.mp.err, ErrClientCap) + require.Nil(t, ErrToClient(ts.mp.err)) + require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err)) + } else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { + require.ErrorIs(t, ts.mp.err, ErrBackendCap) + require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err)) + require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err)) } else { require.NoError(t, ts.mc.err) require.NoError(t, ts.mp.err) @@ -311,6 +314,7 @@ func TestAuthFail(t *testing.T) { ts, clean := newTestSuite(t, tc, cfg) ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { require.Equal(t, len(ts.mc.authData), len(ts.mb.authData)) + require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err)) }) clean() } @@ -318,8 +322,9 @@ func TestAuthFail(t *testing.T) { func TestRequireBackendTLS(t *testing.T) { tests := []struct { - cfg cfgOverrider - errMsg string + cfg cfgOverrider + err error + src ErrorSource }{ { cfg: func(cfg *testConfig) { @@ -327,7 +332,8 @@ func TestRequireBackendTLS(t *testing.T) { cfg.proxyConfig.backendTLSConfig = nil cfg.backendConfig.capability |= pnet.ClientSSL }, - errMsg: requireProxyTLSErrMsg, + err: ErrProxyNoTLS, + src: SrcProxyErr, }, { cfg: func(cfg *testConfig) { @@ -335,7 +341,8 @@ func TestRequireBackendTLS(t *testing.T) { cfg.backendConfig.tlsConfig = nil cfg.backendConfig.capability &= ^pnet.ClientSSL }, - errMsg: requireTiDBTLSErrMsg, + err: ErrBackendNoTLS, + src: SrcBackendHandshake, }, { cfg: func(cfg *testConfig) { @@ -351,10 +358,9 @@ func TestRequireBackendTLS(t *testing.T) { for _, tt := range tests { ts, clean := newTestSuite(t, tc, tt.cfg) ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { - if len(tt.errMsg) > 0 { - var userError *pnet.UserError - require.True(t, errors.As(ts.mp.err, &userError)) - require.Equal(t, tt.errMsg, userError.UserMsg()) + if tt.err != nil { + require.ErrorIs(t, ts.mp.err, tt.err) + require.Equal(t, tt.src, Error2Source(ts.mp.err)) } else { require.NoError(t, ts.mp.err) } @@ -401,9 +407,9 @@ func TestProxyProtocol(t *testing.T) { // TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable. // So when backend enables proxy-protocol and proxy disables it, it still works well. if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol { - var userError *pnet.UserError - require.True(t, errors.As(ts.mp.err, &userError)) - require.Equal(t, checkPPV2ErrMsg, userError.UserMsg()) + err := ErrToClient(ts.mp.err) + require.Equal(t, ErrBackendPPV2, err) + require.Equal(t, SrcBackendHandshake, Error2Source(err)) } else { require.NoError(t, ts.mp.err) } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 5123734d..e0dc05e8 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -148,7 +148,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler // There are 2 types of signals, which may be sent concurrently. signalReceived: make(chan signalType, signalTypeNums), redirectResCh: make(chan *redirectResult, 1), - quitSource: SrcClientQuit, + quitSource: SrcNone, } mgr.SetValue(ConnContextKeyConnID, connectionID) return mgr @@ -170,13 +170,16 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.clientIO = clientIO err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) if err != nil { - mgr.setQuitSourceByErr(err) - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) - clientIO.WriteUserError(err) + src := Error2Source(err) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err, src) + // For some errors, convert them to MySQL errors and send them to the client. + if clientErr := ErrToClient(err); clientErr != nil { + clientIO.WriteUserError(clientErr) + } + mgr.quitSource = src return err } - mgr.resetQuitSource() - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil, SrcNone) mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) @@ -191,7 +194,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { r, err := mgr.handshakeHandler.GetRouter(cctx, resp) if err != nil { - return nil, pnet.WrapUserError(err, err.Error()) + return nil, errors.Wrap(ErrProxyErr, err) } // Reasons to wait: // - The TiDB instances may not be initialized yet @@ -211,17 +214,17 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato addr, err = selector.Next() } if err != nil { - return nil, backoff.Permanent(pnet.WrapUserError(err, err.Error())) + return nil, backoff.Permanent(errors.Wrap(ErrProxyErr, err)) } if addr == "" { - return nil, router.ErrNoInstanceToSelect + return nil, ErrProxyNoBackend } var cn net.Conn cn, err = net.DialTimeout("tcp", addr, DialTimeout) selector.Finish(mgr, err == nil) if err != nil { - return nil, errors.Wrapf(err, "dial backend %s error", addr) + return nil, errors.Wrap(ErrBackendHandshake, errors.Wrapf(err, "dial backend %s error", addr)) } // NOTE: should use DNS name as much as possible @@ -235,8 +238,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { origErr = err - mgr.setQuitSourceByErr(err) - mgr.handshakeHandler.OnHandshake(cctx, addr, err) + mgr.handshakeHandler.OnHandshake(cctx, addr, err, Error2Source(err)) }, ) cancel() @@ -285,7 +287,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } if err != nil { - if !IsMySQLError(err) { + if !pnet.IsMySQLError(err) { return } else { mgr.logger.Debug("got a mysql error", zap.Error(err), zap.Stringer("cmd", cmd)) @@ -305,7 +307,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( mgr.authenticator.capability &^= pnet.ClientMultiStatements mgr.cmdProcessor.capability &^= pnet.ClientMultiStatements default: - err = errors.Errorf("unrecognized set_option value:%d", val) + err = errors.Wrapf(gomysql.ErrMalformPacket, "unrecognized set_option value:%d", val) return } case pnet.ComChangeUser: @@ -321,7 +323,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( // Execute the held request no matter redirection succeeds or not. _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false) addCmdMetrics(cmd, mgr.ServerAddr(), startTime) - if err != nil && !IsMySQLError(err) { + if err != nil && !pnet.IsMySQLError(err) { return } } else if mgr.closeStatus.Load() == statusNotifyClose { @@ -431,7 +433,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { // If the backend connection is closed, also close the client connection. // Otherwise, if the client is idle, the mgr will keep retrying. if errors.Is(rs.err, net.ErrClosed) || pnet.IsDisconnectError(rs.err) || errors.Is(rs.err, os.ErrDeadlineExceeded) { - mgr.quitSource = SrcBackendQuit + mgr.quitSource = SrcBackendNetwork if ignoredErr := mgr.clientIO.GracefulClose(); ignoredErr != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(ignoredErr)) } @@ -442,12 +444,10 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { return } - defer mgr.resetQuitSource() var cn net.Conn cn, rs.err = net.DialTimeout("tcp", rs.to, DialTimeout) if rs.err != nil { - mgr.quitSource = SrcBackendQuit - mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err) + mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err, SrcBackendNetwork) return } newBackendIO := pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)) @@ -455,8 +455,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) } else { - mgr.setQuitSourceByErr(rs.err) - mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err) + mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err, Error2Source(rs.err)) } if rs.err != nil { if ignoredErr := newBackendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { @@ -469,7 +468,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } mgr.backendIO.Store(newBackendIO) mgr.setKeepAlive(mgr.config.HealthyKeepAlive) - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil, SrcNone) } // The original db in the auth info may be dropped during the session, so we need to authenticate with the current db. @@ -553,7 +552,7 @@ func (mgr *BackendConnManager) checkBackendActive() { if !backendIO.IsPeerActive() { mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Stringer("backend_addr", backendIO.RemoteAddr())) - mgr.quitSource = SrcBackendQuit + mgr.quitSource = SrcBackendNetwork if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } @@ -627,7 +626,7 @@ func (mgr *BackendConnManager) Close() error { } mgr.wg.Wait() - handErr := mgr.handshakeHandler.OnConnClose(mgr) + handErr := mgr.handshakeHandler.OnConnClose(mgr, mgr.quitSource) var connErr error var addr string @@ -677,26 +676,16 @@ func (mgr *BackendConnManager) setKeepAlive(cfg config.KeepAlive) { } } -// quitSource will be read by OnHandshake and OnConnClose, so setQuitSourceByErr should be called before them. func (mgr *BackendConnManager) setQuitSourceByErr(err error) { - // Do not update the source if err is nil. It may be already be set. if err == nil { return } - if errors.Is(err, ErrBackendConn) { - mgr.quitSource = SrcBackendQuit - } else if IsMySQLError(err) { - mgr.quitSource = SrcClientErr - } else if !errors.Is(err, ErrClientConn) { - mgr.quitSource = SrcProxyErr + // The source may be already be set. + // E.g. quitSource is set before TiProxy shuts down and client connection error is caused by shutdown instead of network. + if mgr.quitSource != SrcNone { + return } -} - -func (mgr *BackendConnManager) resetQuitSource() { - // SrcClientQuit is by default. - // Sometimes ErrClientConn is caused by GracefulClose and the quitSource is already set. - // Error maybe set during handshake for OnHandshake. If handshake finally succeeds, we reset it. - mgr.quitSource = SrcClientQuit + mgr.quitSource = Error2Source(err) } func (mgr *BackendConnManager) UpdateLogger(fields ...zap.Field) { diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index f07f3329..9d702881 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -150,7 +150,7 @@ func (ts *backendMgrTester) redirectSucceed4Proxy(_, _ *pnet.PacketIO) error { ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed) require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load()) - require.Equal(ts.t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(ts.t, SrcNone, ts.mp.QuitSource()) return nil } @@ -372,7 +372,7 @@ func TestRedirectInTxn(t *testing.T) { require.NoError(t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) require.Equal(t, backend1, ts.mp.backendIO.Load()) - require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(t, SrcNone, ts.mp.QuitSource()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -411,7 +411,7 @@ func TestConnectFail(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcClientErr, ts.mp.QuitSource()) + require.Equal(t, SrcClientAuthFail, ts.mp.QuitSource()) return nil }, }, @@ -624,7 +624,7 @@ func TestCustomHandshake(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(t, SrcNone, ts.mp.QuitSource()) return nil }, }, @@ -767,8 +767,8 @@ func TestHandlerReturnError(t *testing.T) { return router.NewStaticRouter(nil), nil } }, - errMsg: connectErrMsg, - quitSource: SrcProxyErr, + errMsg: ErrProxyNoBackend.Error(), + quitSource: SrcProxyNoBackend, }, } for _, test := range tests { @@ -847,12 +847,12 @@ func TestGetBackendIO(t *testing.T) { getRouter: func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { return rt, nil }, - onHandshake: func(connContext ConnContext, s string, err error) { + onHandshake: func(connContext ConnContext, s string, err error, src ErrorSource) { if err != nil && len(s) > 0 { badAddrs[s] = struct{}{} } if err != nil { - require.Equal(t, SrcProxyErr, connContext.QuitSource()) + require.Equal(t, SrcBackendHandshake, src) } }, } @@ -961,7 +961,7 @@ func TestBackendInactive(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcBackendQuit, ts.mp.QuitSource()) + require.Equal(t, SrcBackendNetwork, ts.mp.QuitSource()) return nil }, }, diff --git a/pkg/proxy/backend/cmd_processor.go b/pkg/proxy/backend/cmd_processor.go index 74fde4b4..2b4eb923 100644 --- a/pkg/proxy/backend/cmd_processor.go +++ b/pkg/proxy/backend/cmd_processor.go @@ -6,7 +6,6 @@ package backend import ( "encoding/binary" - gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/parser/mysql" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" "go.uber.org/zap" @@ -116,12 +115,3 @@ func (cp *CmdProcessor) hasPendingPreparedStmts() bool { } return false } - -// IsMySQLError returns true if the error is a MySQL error. -func IsMySQLError(err error) bool { - if err == nil { - return false - } - _, ok := err.(*gomysql.MyError) - return ok -} diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 620bfd43..c82fb633 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -25,7 +25,7 @@ func (cp *CmdProcessor) executeCmd(request []byte, clientIO, backendIO *pnet.Pac var response []byte if _, response, err = cp.query(backendIO, "COMMIT"); err != nil { // If commit fails, forward the response to the client. - if IsMySQLError(err) { + if pnet.IsMySQLError(err) { if writeErr := clientIO.WritePacket(response, true); writeErr != nil { return false, writeErr } diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index d5050a72..9b09c2d1 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -165,6 +165,7 @@ func TestDirectQuery(t *testing.T) { }, c: func(t *testing.T, ts *testSuite) { require.Error(t, ts.mp.err) + require.Equal(t, SrcClientSQLErr, Error2Source(ts.mp.err)) require.NoError(t, ts.mb.err) }, }, @@ -1027,7 +1028,6 @@ func TestNetworkError(t *testing.T) { clientErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) require.True(t, pnet.IsDisconnectError(ts.mc.err)) - require.NotNil(t, ts.mp.err.(*pnet.UserError)) } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) @@ -1039,10 +1039,13 @@ func TestNetworkError(t *testing.T) { ts, clean := newTestSuite(t, tc, clientExitCfg) ts.authenticateFirstTime(t, backendErrChecker) + require.Equal(t, SrcClientNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) ts.authenticateFirstTime(t, clientErrChecker) + require.Equal(t, ErrBackendHandshake, ErrToClient(ts.mp.err)) + require.Equal(t, SrcBackendNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) @@ -1051,10 +1054,12 @@ func TestNetworkError(t *testing.T) { ts, clean = newTestSuite(t, tc, clientExitCfg) ts.executeCmd(t, backendErrChecker) + require.Equal(t, SrcClientNetwork, Error2Source(ts.mp.err)) clean() - ts, clean = newTestSuite(t, tc, clientExitCfg) - ts.executeCmd(t, backendErrChecker) + ts, clean = newTestSuite(t, tc, backendExitCfg) + ts.executeCmd(t, clientErrChecker) + require.Equal(t, SrcBackendNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index 16056bad..bf0fd8ab 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -59,11 +59,11 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() { if !enableRoute { backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String()) require.NoError(t, err) - tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize) + tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize, pnet.WithWrapError(ErrBackendConn)) } clientConn, err := tc.proxyListener.Accept() require.NoError(t, err) - tc.proxyCIO = pnet.NewPacketIO(clientConn, lg, pnet.DefaultConnBufferSize) + tc.proxyCIO = pnet.NewPacketIO(clientConn, lg, pnet.DefaultConnBufferSize, pnet.WithWrapError(ErrClientConn)) }) wg.Run(func() { conn, err := net.Dial("tcp", tc.proxyListener.Addr().String()) diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 4097126a..666d1199 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -4,20 +4,177 @@ package backend import ( + gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/util/errors" + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" ) -const ( - connectErrMsg = "No available TiDB instances, please make sure TiDB is available" - parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP" - handshakeErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB is available" - capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" - requireProxyTLSErrMsg = "Require TLS enabled on TiProxy when require-backend-tls=true" - requireTiDBTLSErrMsg = "Require TLS enabled on TiDB when require-backend-tls=true" - checkPPV2ErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB proxy-protocol is set correctly. If this error still exists, please contact PingCAP" +// These errors may not be disconnection errors. They are used for marking whether the error comes from the client or the backend. +var ( + ErrClientConn = errors.New("this is an error from the client connection") + ErrBackendConn = errors.New("this is an error from the backend connection") ) +// These errors are used to track internal errors. var ( - ErrClientConn = errors.New("this is an error from client") - ErrBackendConn = errors.New("this is an error from backend") + ErrClientCap = errors.New("Verify client capability failed, please upgrade the client") + ErrClientHandshake = errors.New("Fails to handshake with the client") + ErrClientAuthFail = errors.New("Authentication fails") + ErrProxyErr = errors.New("Other serverless error") + ErrProxyNoBackend = errors.New("No available TiDB instances, please make sure TiDB is available") + ErrProxyNoTLS = errors.New("Require TLS enabled on TiProxy when require-backend-tls=true") + ErrBackendCap = errors.New("Verify TiDB capability failed, please upgrade TiDB") + ErrBackendHandshake = errors.New("TiProxy fails to connect to TiDB, please make sure TiDB is available") + ErrBackendNoTLS = errors.New("Require TLS enabled on TiDB when require-backend-tls=true") + ErrBackendPPV2 = errors.New("TiProxy fails to connect to TiDB, please make sure TiDB proxy-protocol is set correctly. If this error still exists, please contact PingCAP") +) + +// ErrToClient returns the error that needs to be sent to the client. +func ErrToClient(err error) error { + switch { + case pnet.IsMySQLError(err): + // If it's a MySQL error, it should be already sent to the client. + return nil + case errors.Is(err, ErrProxyNoBackend): + return ErrProxyNoBackend + case errors.Is(err, ErrProxyNoTLS): + return ErrProxyNoTLS + case errors.Is(err, ErrBackendCap): + return ErrBackendCap + case errors.Is(err, ErrBackendHandshake): + return ErrBackendHandshake + case errors.Is(err, ErrBackendNoTLS): + return ErrBackendNoTLS + case errors.Is(err, ErrBackendPPV2): + return ErrBackendPPV2 + case errors.Is(err, ErrProxyErr): + // The error is returned by HandshakeHandler/BackendFetcher and wrapped with ErrProxyErr. + return errors.Unwrap(err) + } + // For other errors, we don't send them to the client. + return nil +} + +type SourceComp int + +const ( + CompNone SourceComp = iota + CompClient + CompProxy + CompBackend ) + +type ErrorSource int + +const ( + // SrcNone includes: succeed for OnHandshake; client normally quit for OnConnClose + SrcNone ErrorSource = iota + // SrcClientNetwork includes: EOF; reset by peer; connection refused; io timeout + SrcClientNetwork + // SrcClientHandshake includes: client capability unsupported; TLS handshake fails + SrcClientHandshake + // SrcClientAuthFail includes: backend returns auth fail + SrcClientAuthFail + // SrcClientSQLErr includes: SQL error + SrcClientSQLErr + // SrcProxyQuit includes: proxy graceful shutdown + SrcProxyQuit + // SrcProxyMalformed includes: malformed packet; invalid sequence + SrcProxyMalformed + // SrcProxyNoBackend includes: no backends + SrcProxyNoBackend + // SrcProxyErr includes: HandshakeHandler returns error; proxy disables TLS; unexpected errors + SrcProxyErr + // SrcBackendNetwork includes: EOF; reset by peer; connection refused; io timeout + SrcBackendNetwork + // SrcBackendHandshake includes: dial failure; backend capability unsupported; backend disables TLS; TLS handshake fails; proxy protocol fails + SrcBackendHandshake +) + +// Error2Source returns the ErrorSource by the error. +func Error2Source(err error) ErrorSource { + if err == nil { + return SrcNone + } + // Disconnection errors may come from other errors such as ErrProxyNoBackend and ErrBackendHandshake. + // ErrClientConn and ErrBackendConn may include non-connection errors. + if pnet.IsDisconnectError(err) { + if errors.Is(err, ErrClientConn) { + return SrcClientNetwork + } else if errors.Is(err, ErrBackendConn) { + return SrcBackendNetwork + } + } + switch { + // ErrInvalidSequence and ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. + case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, gomysql.ErrMalformPacket): + // We assume the clients and TiDB are right and treat it as TiProxy bugs. + return SrcProxyMalformed + case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap): + return SrcClientHandshake + case errors.Is(err, ErrClientAuthFail): + return SrcClientAuthFail + case errors.Is(err, ErrBackendHandshake), errors.Is(err, ErrBackendCap), errors.Is(err, ErrBackendNoTLS), errors.Is(err, ErrBackendPPV2): + return SrcBackendHandshake + case errors.Is(err, ErrProxyNoBackend): + return SrcProxyNoBackend + case pnet.IsMySQLError(err): + // ErrClientAuthFail and ErrBackendHandshake may also contain MySQL error. + return SrcClientSQLErr + default: + // All other untracked errors are proxy errors. + return SrcProxyErr + } +} + +// String is used for metrics labels and log. +func (es ErrorSource) String() string { + switch es { + case SrcNone: + return "success" + case SrcClientNetwork: + return "client network break" + case SrcClientHandshake: + return "client handshake fail" + case SrcClientAuthFail: + return "auth fail" + case SrcClientSQLErr: + return "SQL error" + case SrcProxyQuit: + return "proxy shutdown" + case SrcProxyMalformed: + return "malformed packet" + case SrcProxyNoBackend: + return "get backend fail" + case SrcProxyErr: + return "proxy error" + case SrcBackendNetwork: + return "backend network break" + case SrcBackendHandshake: + return "backend handshake fail" + } + return "unknown" +} + +// GetSourceComp returns which component does this error belong to. +func (es ErrorSource) GetSourceComp() SourceComp { + switch es { + case SrcClientNetwork, SrcClientHandshake, SrcClientAuthFail, SrcClientSQLErr: + return CompClient + case SrcProxyQuit, SrcProxyMalformed, SrcProxyNoBackend, SrcProxyErr: + return CompProxy + case SrcBackendNetwork, SrcBackendHandshake: + return CompBackend + default: + return CompNone + } +} + +// Normal returns whether this error source is expected. +func (es ErrorSource) Normal() bool { + switch es { + case SrcNone, SrcProxyQuit, SrcClientNetwork, SrcClientSQLErr: + return true + } + return false +} diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index cb2665ac..d400540d 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -12,6 +12,8 @@ import ( "go.uber.org/zap" ) +// Interfaces in this file are used for the serverless tier. + // Context keys. type ConnContextKey string @@ -21,41 +23,6 @@ const ( ConnContextKeyConnAddr ConnContextKey = "conn-addr" ) -type ErrorSource int - -const ( - // SrcClientQuit includes: client quit; bad client conn - SrcClientQuit ErrorSource = iota - // SrcClientErr includes: wrong password; mal format packet - SrcClientErr - // SrcProxyQuit includes: proxy graceful shutdown - SrcProxyQuit - // SrcProxyErr includes: cannot get backend list; capability negotiation - SrcProxyErr - // SrcBackendQuit includes: backend quit - SrcBackendQuit - // SrcBackendErr is reserved - SrcBackendErr -) - -func (es ErrorSource) String() string { - switch es { - case SrcClientQuit: - return "client quit" - case SrcClientErr: - return "client error" - case SrcProxyQuit: - return "proxy shutdown" - case SrcProxyErr: - return "proxy error" - case SrcBackendQuit: - return "backend quit" - case SrcBackendErr: - return "backend error" - } - return "unknown" -} - var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) var _ HandshakeHandler = (*CustomHandshakeHandler)(nil) @@ -64,7 +31,6 @@ type ConnContext interface { ServerAddr() string ClientInBytes() uint64 ClientOutBytes() uint64 - QuitSource() ErrorSource UpdateLogger(fields ...zap.Field) SetValue(key, val any) Value(key any) any @@ -74,8 +40,8 @@ type HandshakeHandler interface { HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool // return true means retry connect GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) - OnHandshake(ctx ConnContext, to string, err error) - OnConnClose(ctx ConnContext) error + OnHandshake(ctx ConnContext, to string, err error, src ErrorSource) + OnConnClose(ctx ConnContext, src ErrorSource) error OnTraffic(ctx ConnContext) GetCapability() pnet.Capability GetServerVersion() string @@ -111,13 +77,13 @@ func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Ha return ns.GetRouter(), nil } -func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error) { +func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error, ErrorSource) { } func (handler *DefaultHandshakeHandler) OnTraffic(ConnContext) { } -func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error { +func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext, ErrorSource) error { return nil } @@ -140,9 +106,9 @@ func (handler *DefaultHandshakeHandler) GetServerVersion() string { type CustomHandshakeHandler struct { getRouter func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) - onHandshake func(ConnContext, string, error) + onHandshake func(ConnContext, string, error, ErrorSource) onTraffic func(ConnContext) - onConnClose func(ConnContext) error + onConnClose func(ConnContext, ErrorSource) error handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error handleHandshakeErr func(ctx ConnContext, err *gomysql.MyError) bool getCapability func() pnet.Capability @@ -156,9 +122,9 @@ func (h *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Handshake return nil, errors.New("no router") } -func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error) { +func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error, src ErrorSource) { if h.onHandshake != nil { - h.onHandshake(ctx, addr, err) + h.onHandshake(ctx, addr, err, src) } } @@ -168,9 +134,9 @@ func (h *CustomHandshakeHandler) OnTraffic(ctx ConnContext) { } } -func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { +func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext, src ErrorSource) error { if h.onConnClose != nil { - return h.onConnClose(ctx) + return h.onConnClose(ctx, src) } return nil } diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index f323e525..303a34f7 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -162,7 +162,7 @@ func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendR require.NoError(t, ts.mc.err) require.NoError(t, ts.mb.err) if ts.mb.err != nil { - require.True(t, IsMySQLError(ts.mp.err)) + require.True(t, pnet.IsMySQLError(ts.mp.err)) } if clientRunner != nil && backendRunner != nil { // Ensure all the packets are forwarded. @@ -190,6 +190,9 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { if ts.mc.capability&pnet.ClientConnectAttrs > 0 { require.Equal(t, ts.mc.attrs, ts.mb.attrs) } + if !ts.mb.authSucceed { + require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err)) + } } } diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index a2fc4f50..b66afe16 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -57,11 +57,9 @@ func (cc *ClientConnection) Run(ctx context.Context) { clean: src := cc.connMgr.QuitSource() - switch src { - case backend.SrcClientQuit, backend.SrcClientErr, backend.SrcProxyQuit: - default: + if !src.Normal() { fields := cc.connMgr.ConnInfo() - fields = append(fields, zap.Stringer("quit source", src), zap.Error(err)) + fields = append(fields, zap.Stringer("quit_source", src), zap.Error(err)) cc.logger.Warn(msg, fields...) } } diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go index 2f4dd407..fb6a9be5 100644 --- a/pkg/proxy/net/compress.go +++ b/pkg/proxy/net/compress.go @@ -8,6 +8,7 @@ import ( "compress/zlib" "io" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/klauspost/compress/zstd" "github.com/pingcap/tiproxy/lib/util/errors" "go.uber.org/zap" @@ -45,7 +46,7 @@ func (p *PacketIO) SetCompressionAlgorithm(algorithm CompressAlgorithm, zstdLeve p.readWriter = newCompressedReadWriter(p.readWriter, algorithm, zstdLevel, p.logger) case CompressionNone: default: - return errors.Errorf("Unknown compression algorithm %d", algorithm) + return errors.Wrapf(mysql.ErrMalformPacket, "Unknown compression algorithm %d", algorithm) } return nil } @@ -272,7 +273,7 @@ func (crw *compressedReadWriter) compress(data []byte) ([]byte, error) { compressWriter, err = zstd.NewWriter(&compressedPacket, zstd.WithEncoderLevel(crw.zstdLevel)) } if err != nil { - return nil, errors.WithStack(err) + return nil, errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } if _, err = compressWriter.Write(data); err != nil { return nil, errors.WithStack(err) @@ -289,12 +290,12 @@ func (crw *compressedReadWriter) uncompress(data []byte, uncompressedLength int) switch crw.algorithm { case CompressionZlib: if compressedReader, err = zlib.NewReader(bytes.NewReader(data)); err != nil { - return errors.WithStack(err) + return errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } case CompressionZstd: var decoder *zstd.Decoder if decoder, err = zstd.NewReader(bytes.NewReader(data)); err != nil { - return errors.WithStack(err) + return errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } compressedReader = decoder.IOReadCloser() } diff --git a/pkg/proxy/net/error.go b/pkg/proxy/net/error.go index 00409b32..9c7f290d 100644 --- a/pkg/proxy/net/error.go +++ b/pkg/proxy/net/error.go @@ -15,35 +15,3 @@ var ( ErrCloseConn = errors.New("failed to close the connection") ErrHandshakeTLS = errors.New("failed to complete tls handshake") ) - -// UserError is returned to the client. -// err is used to log and userMsg is used to report to the user. -type UserError struct { - err error - userMsg string -} - -func WrapUserError(err error, userMsg string) *UserError { - if err == nil { - return nil - } - if ue, ok := err.(*UserError); ok { - return ue - } - return &UserError{ - err: err, - userMsg: userMsg, - } -} - -func (ue *UserError) UserMsg() string { - return ue.userMsg -} - -func (ue *UserError) Unwrap() error { - return ue.err -} - -func (ue *UserError) Error() string { - return ue.err.Error() -} diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index d9b56f56..cca298c6 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -32,7 +32,7 @@ func ParseInitialHandshake(data []byte) (Capability, uint64, string) { // skip min version serverVersion := string(data[1 : 1+bytes.IndexByte(data[1:], 0)]) pos := 1 + len(serverVersion) + 1 - connid := uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) + connid := binary.LittleEndian.Uint32(data[pos : pos+4]) // skip salt first part // skip filter pos += 4 + 8 + 1 @@ -398,7 +398,7 @@ func ParseOKPacket(data []byte) uint16 { } // ParseErrorPacket transforms an error packet into a MyError object. -func ParseErrorPacket(data []byte) error { +func ParseErrorPacket(data []byte) *gomysql.MyError { e := new(gomysql.MyError) pos := 1 e.Code = binary.LittleEndian.Uint16(data[pos:]) @@ -433,6 +433,12 @@ func IsErrorPacket(firstByte byte) bool { return firstByte == ErrHeader.Byte() } +// IsMySQLError returns true if the error is a MySQL error. +func IsMySQLError(err error) bool { + var myerr *gomysql.MyError + return errors.As(err, &myerr) +} + // The connection attribute names that are logged. // https://dev.mysql.com/doc/mysql-perfschema-excerpt/8.2/en/performance-schema-connection-attribute-tables.html const ( diff --git a/pkg/proxy/net/mysql_test.go b/pkg/proxy/net/mysql_test.go index d777b62a..d951d1ab 100644 --- a/pkg/proxy/net/mysql_test.go +++ b/pkg/proxy/net/mysql_test.go @@ -6,6 +6,8 @@ package net import ( "testing" + gomysql "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/lib/util/logger" "github.com/stretchr/testify/require" ) @@ -61,3 +63,12 @@ func TestLogAttrs(t *testing.T) { require.Contains(t, str, `"client_name": "libmysql"`) require.Contains(t, str, `"program_name": "mysql"`) } + +func TestMySQLError(t *testing.T) { + myerr := &gomysql.MyError{} + require.True(t, IsMySQLError(errors.Wrap(ErrHandshakeTLS, myerr))) + require.False(t, IsMySQLError(errors.Wrap(myerr, ErrHandshakeTLS))) + require.False(t, IsMySQLError(ErrHandshakeTLS)) + require.True(t, errors.Is(errors.Wrap(ErrHandshakeTLS, myerr), ErrHandshakeTLS)) + require.True(t, errors.Is(errors.Wrap(myerr, ErrHandshakeTLS), ErrHandshakeTLS)) +} diff --git a/pkg/proxy/net/net_err.go b/pkg/proxy/net/net_err.go index b1659356..157c7f05 100644 --- a/pkg/proxy/net/net_err.go +++ b/pkg/proxy/net/net_err.go @@ -4,7 +4,9 @@ package net import ( + "context" "io" + "os" "syscall" "github.com/pingcap/tiproxy/lib/util/errors" @@ -13,7 +15,8 @@ import ( // IsDisconnectError returns whether the error is caused by peer disconnection. func IsDisconnectError(err error) bool { switch { - case errors.Is(err, io.EOF), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET): + case errors.Is(err, io.EOF), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET), + errors.Is(err, os.ErrDeadlineExceeded), errors.Is(err, context.DeadlineExceeded): return true } return false diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 8b5013b1..201a6b33 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -333,7 +333,7 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first for { header, err := p.readWriter.Peek(5) if err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 end, needData := isEnd(header[4], length) @@ -343,30 +343,30 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first // TODO: allocate a buffer from pool and return the buffer after `process`. data, err = p.ReadPacket() if err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } if err := dest.WritePacket(data, false); err != nil { - return errors.Wrap(ErrWriteConn, err) + return p.wrapErr(errors.Wrap(ErrWriteConn, err)) } } else { for { sequence, pktSequence := header[3], p.readWriter.Sequence() if sequence != pktSequence { - return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence) + return p.wrapErr(ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)) } p.readWriter.SetSequence(sequence + 1) // Sequence may be different (e.g. with compression) so we can't just copy the data to the destination. dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1) p.limitReader.N = int64(length + 4) if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil { - return errors.Wrap(ErrRelayConn, err) + return p.wrapErr(errors.Wrap(ErrRelayConn, err)) } // For large packets, continue. if length < MaxPayloadLen { break } if header, err = p.readWriter.Peek(4); err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } length = int(header[0]) | int(header[1])<<8 | int(header[2])<<16 } diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index f17ba3e3..bbf60f5a 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -85,7 +85,7 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err if len(pkt) < 32 { p.logger.Error("got malformed handshake response", zap.ByteString("packetData", pkt)) - err = WrapUserError(mysql.ErrMalformPacket, mysql.ErrMalformPacket.Error()) + err = mysql.ErrMalformPacket return } @@ -132,11 +132,7 @@ func (p *PacketIO) WriteUserError(err error) { if err == nil { return } - var ue *UserError - if !errors.As(err, &ue) { - return - } - myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()) + myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, err.Error()) if writeErr := p.WriteErrPacket(myErr); writeErr != nil { p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr)) } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 100ae4d1..4c092437 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -265,13 +265,13 @@ func TestPredefinedPacket(t *testing.T) { func(t *testing.T, cli *PacketIO) { data, err := cli.ReadPacket() require.NoError(t, err) - merr := ParseErrorPacket(data).(*mysql.MyError) + merr := ParseErrorPacket(data) require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) require.Equal(t, "Unknown error", merr.Message) data, err = cli.ReadPacket() require.NoError(t, err) - merr = ParseErrorPacket(data).(*mysql.MyError) + merr = ParseErrorPacket(data) require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) require.Equal(t, "test error", merr.Message) diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index 4ad416ec..9303ecf3 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -39,7 +39,7 @@ func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error { conn := &tlsInternalConn{p.readWriter} tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { - return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err)) + return p.wrapErr(errors.Wrap(ErrHandshakeTLS, err)) } p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) return nil