From df27427e6a15f2eee8dad6658472541cd3225a49 Mon Sep 17 00:00:00 2001 From: namelessssssssssss <1544669126@qq.com> Date: Mon, 29 May 2023 12:20:25 +0800 Subject: [PATCH] Add bi-stream call support for triple --- .../bootstrap/DefaultClientProxyInvoker.java | 2 +- .../sofa/rpc/client/AbstractCluster.java | 4 + .../alipay/sofa/rpc/common/RpcConstants.java | 16 ++ .../rpc/config/AbstractInterfaceConfig.java | 22 +- .../sofa/rpc/config/ConsumerConfig.java | 59 ++++- .../alipay/sofa/rpc/config/MethodConfig.java | 43 +++- .../sofa/rpc/config/ProviderConfig.java | 1 + .../sofa/rpc/message/MessageBuilder.java | 15 ++ .../sofa/rpc/transport/StreamHandler.java | 48 ++++ .../sofa/rpc/core/exception/RpcErrorType.java | 5 + .../rpc/filter/ConsumerGenericFilter.java | 2 +- .../proto/main/java/triple/GenericProto.java | 9 +- .../main/java/triple/GenericServiceGrpc.java | 125 +++++++++ .../java/triple/SofaGenericServiceTriple.java | 62 +++++ .../stream/ClientStreamObserverAdapter.java | 71 ++++++ .../ResponseSerializeStreamHandler.java | 64 +++++ .../rpc/server/triple/GenericServiceImpl.java | 238 ++++++++++++++++-- .../sofa/rpc/server/triple/TripleServer.java | 44 ++-- .../rpc/server/triple/UniqueIdInvoker.java | 18 +- .../transport/triple/TripleClientInvoker.java | 233 ++++++++++++----- .../alipay/sofa/rpc/utils/SofaProtoUtils.java | 24 ++ .../sofa/rpc/utils/TripleExceptionUtils.java | 32 +++ .../src/main/proto/transformer.proto | 5 + .../server/triple/GenericServiceImplTest.java | 4 +- .../src/main/proto/helloworld.proto | 2 + .../sofa/rpc/test/triple/GreeterImpl.java | 20 ++ .../rpc/test/triple/TripleServerTest.java | 53 ++++ .../sofa/rpc/triple/stream/ClientRequest.java | 36 +++ .../sofa/rpc/triple/stream/HelloService.java | 39 +++ .../rpc/triple/stream/HelloServiceImpl.java | 90 +++++++ .../rpc/triple/stream/ServerResponse.java | 36 +++ .../stream/TripleGenericStreamTest.java | 208 +++++++++++++++ 32 files changed, 1522 insertions(+), 108 deletions(-) create mode 100644 core/api/src/main/java/com/alipay/sofa/rpc/transport/StreamHandler.java create mode 100644 remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ClientStreamObserverAdapter.java create mode 100644 remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ResponseSerializeStreamHandler.java create mode 100644 remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/TripleExceptionUtils.java create mode 100644 test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ClientRequest.java create mode 100644 test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloService.java create mode 100644 test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloServiceImpl.java create mode 100644 test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ServerResponse.java create mode 100644 test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/TripleGenericStreamTest.java diff --git a/bootstrap/bootstrap-api/src/main/java/com/alipay/sofa/rpc/bootstrap/DefaultClientProxyInvoker.java b/bootstrap/bootstrap-api/src/main/java/com/alipay/sofa/rpc/bootstrap/DefaultClientProxyInvoker.java index 0aadda9c9..0daa4bb56 100644 --- a/bootstrap/bootstrap-api/src/main/java/com/alipay/sofa/rpc/bootstrap/DefaultClientProxyInvoker.java +++ b/bootstrap/bootstrap-api/src/main/java/com/alipay/sofa/rpc/bootstrap/DefaultClientProxyInvoker.java @@ -92,7 +92,7 @@ protected void decorateRequest(SofaRequest request) { if (!consumerConfig.isGeneric()) { // 找到调用类型, generic的时候类型在filter里进行判断 - request.setInvokeType(consumerConfig.getMethodInvokeType(request.getMethodName())); + request.setInvokeType(consumerConfig.getMethodInvokeType(request)); } RpcInvokeContext invokeCtx = RpcInvokeContext.peekContext(); diff --git a/core-impl/client/src/main/java/com/alipay/sofa/rpc/client/AbstractCluster.java b/core-impl/client/src/main/java/com/alipay/sofa/rpc/client/AbstractCluster.java index 33fc4cc2a..60280f440 100644 --- a/core-impl/client/src/main/java/com/alipay/sofa/rpc/client/AbstractCluster.java +++ b/core-impl/client/src/main/java/com/alipay/sofa/rpc/client/AbstractCluster.java @@ -657,6 +657,10 @@ else if (RpcConstants.INVOKER_TYPE_FUTURE.equals(invokeType)) { // 放入线程上下文 RpcInternalContext.getContext().setFuture(future); response = buildEmptyResponse(request); + } else if (RpcConstants.INVOKER_TYPE_CLIENT_STREAMING.equals(invokeType) + || RpcConstants.INVOKER_TYPE_BI_STREAMING.equals(invokeType) + || RpcConstants.INVOKER_TYPE_SERVER_STREAMING.equals(invokeType)) { + response = transport.syncSend(request, Integer.MAX_VALUE); } else { throw new SofaRpcException(RpcErrorType.CLIENT_UNDECLARED_ERROR, "Unknown invoke type:" + invokeType); } diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/common/RpcConstants.java b/core/api/src/main/java/com/alipay/sofa/rpc/common/RpcConstants.java index b09cdfe7f..8c77e3223 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/common/RpcConstants.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/common/RpcConstants.java @@ -121,6 +121,22 @@ public class RpcConstants { * 调用方式:future */ public static final String INVOKER_TYPE_FUTURE = "future"; + /** + * 调用方式:一元调用 + */ + public static final String INVOKER_TYPE_UNARY = "unary"; + /** + * 调用方式:客户端流 + */ + public static final String INVOKER_TYPE_CLIENT_STREAMING = "clientStream"; + /** + * 调用方式:服务端流 + */ + public static final String INVOKER_TYPE_SERVER_STREAMING = "serverStream"; + /** + * 调用方式:双向流 + */ + public static final String INVOKER_TYPE_BI_STREAMING = "bidirectionalStream"; /** * Hessian序列化 [不推荐] diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/config/AbstractInterfaceConfig.java b/core/api/src/main/java/com/alipay/sofa/rpc/config/AbstractInterfaceConfig.java index 336292369..4cdf1674b 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/config/AbstractInterfaceConfig.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/config/AbstractInterfaceConfig.java @@ -212,6 +212,11 @@ public abstract class AbstractInterfaceConfig configValueCache = null; + /** + * 方法调用类型,(方法全名 - 调用类型) + */ + protected transient volatile Map methodCallType = null; + /** * 代理接口类,和T对应,主要针对泛化调用 */ @@ -247,6 +252,21 @@ public S setProxyClass(Class proxyClass) { return castThis(); } + /** + * Cache the call type of interface methods + */ + protected void loadMethodCallType(Class interfaceClass){ + Method[] methods = interfaceClass.getDeclaredMethods(); + this.methodCallType = new ConcurrentHashMap<>(); + for(Method method :methods) { + methodCallType.put(method.getName(),MethodConfig.mapStreamType(method,RpcConstants.INVOKER_TYPE_SYNC)); + } + } + + public String getMethodCallType(String methodName) { + return methodCallType.get(methodName); + } + /** * Gets application. * @@ -1015,7 +1035,7 @@ public Object getMethodConfigValue(String methodName, String configKey) { * @param key the key * @return the string */ - private String buildmkey(String methodName, String key) { + protected String buildmkey(String methodName, String key) { return RpcConstants.HIDE_KEY_PREFIX + methodName + RpcConstants.HIDE_KEY_PREFIX + key; } diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/config/ConsumerConfig.java b/core/api/src/main/java/com/alipay/sofa/rpc/config/ConsumerConfig.java index 1b72bc901..e151665a1 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/config/ConsumerConfig.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/config/ConsumerConfig.java @@ -27,6 +27,7 @@ import com.alipay.sofa.rpc.common.utils.ExceptionUtils; import com.alipay.sofa.rpc.common.utils.StringUtils; import com.alipay.sofa.rpc.core.invoke.SofaResponseCallback; +import com.alipay.sofa.rpc.core.request.SofaRequest; import com.alipay.sofa.rpc.listener.ChannelListener; import com.alipay.sofa.rpc.listener.ConsumerStateListener; import com.alipay.sofa.rpc.listener.ProviderInfoListener; @@ -935,12 +936,62 @@ public SofaResponseCallback getMethodOnreturn(String methodName) { /** * Gets the call type corresponding to the method name * - * @param methodName the method name + * @param sofaRequest the request * @return the call type */ - public String getMethodInvokeType(String methodName) { - return (String) getMethodConfigValue(methodName, RpcConstants.CONFIG_KEY_INVOKE_TYPE, - getInvokeType()); + public String getMethodInvokeType(SofaRequest sofaRequest) { + String methodName = sofaRequest.getMethodName(); + + String invokeType = (String) getMethodConfigValue(methodName, RpcConstants.CONFIG_KEY_INVOKE_TYPE, null); + + if (invokeType == null) { + invokeType = getAndCacheCallType(sofaRequest); + } + + return invokeType; + } + + /** + * Get and cache the call type of certain method + * @param request RPC request + * @return request call type + */ + public String getAndCacheCallType(SofaRequest request) { + Method method = request.getMethod(); + String callType = MethodConfig + .mapStreamType( + method, + (String) getMethodConfigValue(request.getMethodName(), RpcConstants.CONFIG_KEY_INVOKE_TYPE, + getInvokeType()) + ); + //Method level config + updateAttribute(buildMethodConfigKey(request, RpcConstants.CONFIG_KEY_INVOKE_TYPE), callType, true); + return callType; + } + + /** + * 通过请求的目标方法构建方法配置key。该key使用内部配置格式。(以'.' 开头) + * @param request RPC请求 + * @return 方法配置名称,带方法参数列表 + */ + public String buildMethodConfigKey(SofaRequest request, String propertyKey) { + return "." + getMethodSignature(request.getMethod()) + "." + propertyKey; + } + + public static String getMethodSignature(Method method) { + Class[] parameterTypes = method.getParameterTypes(); + StringBuilder methodSignature = new StringBuilder(); + methodSignature.append(method.getName()).append("("); + + for (int i = 0; i < parameterTypes.length; i++) { + methodSignature.append(parameterTypes[i].getSimpleName()); + if (i < parameterTypes.length - 1) { + methodSignature.append(", "); + } + } + + methodSignature.append(")"); + return methodSignature.toString(); } /** diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/config/MethodConfig.java b/core/api/src/main/java/com/alipay/sofa/rpc/config/MethodConfig.java index 611795c49..265995505 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/config/MethodConfig.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/config/MethodConfig.java @@ -16,9 +16,15 @@ */ package com.alipay.sofa.rpc.config; +import com.alipay.sofa.rpc.common.RpcConstants; +import com.alipay.sofa.rpc.core.exception.RpcErrorType; +import com.alipay.sofa.rpc.core.exception.SofaRpcException; import com.alipay.sofa.rpc.core.invoke.SofaResponseCallback; +import com.alipay.sofa.rpc.transport.StreamHandler; import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.Arrays; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -43,7 +49,8 @@ public class MethodConfig implements Serializable { protected Map parameters; /** - * The Timeout. 远程调用超时时间(毫秒) + * The Timeout. 远程调用超时时间(毫秒)。 + * 对于Stream调用,这个时间为整个调用的超时时长,而非stream内单个调用的时长。未在这个时长内完成(调用{@link StreamHandler#onFinish()})的Stream调用会认为超时并抛出异常。 */ protected Integer timeout; @@ -326,4 +333,38 @@ public MethodConfig setParameter(String key, String value) { public String getParameter(String key) { return parameters == null ? null : parameters.get(key); } + + /** + * Gets the stream call type of certain method + * @param method the method + * @return call type,server/client/bidirectional stream or default value. If not mapped to any stream call type, use the default value + */ + public static String mapStreamType(Method method, String defaultValue){ + Class[] paramClasses = method.getParameterTypes(); + Class returnClass = method.getReturnType(); + + int paramLen = paramClasses.length; + String callType; + + //BidirectionalStream & ClientStream + if(paramLen > 0 && StreamHandler.class.isAssignableFrom(paramClasses[0]) && StreamHandler.class.isAssignableFrom(returnClass)){ + if(paramLen > 1){ + throw new SofaRpcException(RpcErrorType.CLIENT_CALL_TYPE,"Bidirectional/Client stream method parameters can be only one StreamHandler."); + } + callType = RpcConstants.INVOKER_TYPE_BI_STREAMING; + } + //ServerStream + else if (paramLen > 1 && StreamHandler.class.isAssignableFrom(paramClasses[0]) && void.class == returnClass){ + callType = RpcConstants.INVOKER_TYPE_SERVER_STREAMING; + } + else if (StreamHandler.class.isAssignableFrom(returnClass) || Arrays.stream(paramClasses).anyMatch(StreamHandler.class::isAssignableFrom)) { + throw new SofaRpcException(RpcErrorType.CLIENT_CALL_TYPE, "StreamHandler can only at the specified location of parameter. Please check related docs."); + } + //Other call types + else { + callType = defaultValue; + } + + return callType; + } } diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/config/ProviderConfig.java b/core/api/src/main/java/com/alipay/sofa/rpc/config/ProviderConfig.java index 33b0d6897..cc8872d0e 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/config/ProviderConfig.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/config/ProviderConfig.java @@ -208,6 +208,7 @@ public T getRef() { */ public ProviderConfig setRef(T ref) { this.ref = ref; + loadMethodCallType(ref.getClass()); return this; } diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/message/MessageBuilder.java b/core/api/src/main/java/com/alipay/sofa/rpc/message/MessageBuilder.java index 6607208fc..c143d5742 100644 --- a/core/api/src/main/java/com/alipay/sofa/rpc/message/MessageBuilder.java +++ b/core/api/src/main/java/com/alipay/sofa/rpc/message/MessageBuilder.java @@ -69,6 +69,21 @@ public static SofaRequest buildSofaRequest(Class clazz, Method method, Class[ return request; } + /** + * 根据一个请求的属性复制一个不包含具体方法实参的请求。 + * 复制以下属性:请求接口名、请求方法名、请求方法、方法参数类型 + * + * @param sofaRequest 被复制的请求实例 + */ + public static SofaRequest copyEmptyRequest(SofaRequest sofaRequest) { + SofaRequest request = new SofaRequest(); + request.setInterfaceName(sofaRequest.getInterfaceName()); + request.setMethodName(sofaRequest.getMethodName()); + request.setMethod(sofaRequest.getMethod()); + request.setMethodArgSigs(sofaRequest.getMethodArgSigs()); + return request; + } + /** * 构建rpc错误结果 * diff --git a/core/api/src/main/java/com/alipay/sofa/rpc/transport/StreamHandler.java b/core/api/src/main/java/com/alipay/sofa/rpc/transport/StreamHandler.java new file mode 100644 index 000000000..89494a07d --- /dev/null +++ b/core/api/src/main/java/com/alipay/sofa/rpc/transport/StreamHandler.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.transport; + +/** + * StreamHandler, works just like gRPC StreamObserver. + */ +public interface StreamHandler { + + /** + * Sends a message, or defines the behavior when a message is received. + *

This method should never be called after {@link StreamHandler#onFinish()} has been invoked. + */ + void onMessage(T message); + + /** + * Note: This method MUST be invoked after the transport is complete. + * Failure to do so may result in unexpected errors. + *

+ * Signals that all messages have been sent/received normally, and closes this stream. + */ + void onFinish(); + + /** + * Signals an exception to terminate this stream, or defines the behavior when an error occurs. + *

+ * Once this method is invoked by one side, it can't send more messages, and the corresponding method on the other side will be triggered. + * Depending on the protocol implementation, it's possible that the other side can still call {@link StreamHandler#onMessage(Object)} after this method has been invoked, although this is not recommended. + *

+ * As a best practice, it is advised not to send any more information once this method is called. + * + */ + void onException(Throwable throwable); +} diff --git a/core/exception/src/main/java/com/alipay/sofa/rpc/core/exception/RpcErrorType.java b/core/exception/src/main/java/com/alipay/sofa/rpc/core/exception/RpcErrorType.java index 5356e1bbd..937165e29 100644 --- a/core/exception/src/main/java/com/alipay/sofa/rpc/core/exception/RpcErrorType.java +++ b/core/exception/src/main/java/com/alipay/sofa/rpc/core/exception/RpcErrorType.java @@ -93,6 +93,11 @@ public class RpcErrorType { */ public static final int CLIENT_NETWORK = 250; + /** + * 不支持的RPC调用方式异常 + */ + public static final int CLIENT_CALL_TYPE = 260; + /** * 客户端过滤器异常 */ diff --git a/remoting/remoting-bolt/src/main/java/com/alipay/sofa/rpc/filter/ConsumerGenericFilter.java b/remoting/remoting-bolt/src/main/java/com/alipay/sofa/rpc/filter/ConsumerGenericFilter.java index 0339415c7..1a30a39ca 100644 --- a/remoting/remoting-bolt/src/main/java/com/alipay/sofa/rpc/filter/ConsumerGenericFilter.java +++ b/remoting/remoting-bolt/src/main/java/com/alipay/sofa/rpc/filter/ConsumerGenericFilter.java @@ -90,7 +90,7 @@ public SofaResponse invoke(FilterInvoker invoker, SofaRequest request) throws So // 修正类型 ConsumerConfig consumerConfig = (ConsumerConfig) invoker.getConfig(); - String invokeType = consumerConfig.getMethodInvokeType(methodName); + String invokeType = consumerConfig.getMethodInvokeType(request); request.setInvokeType(invokeType); request.addRequestProp(RemotingConstants.HEAD_INVOKE_TYPE, invokeType); request.addRequestProp(REVISE_KEY, REVISE_VALUE); diff --git a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericProto.java b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericProto.java index 808f3a3a7..37f91d33e 100644 --- a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericProto.java +++ b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericProto.java @@ -36,9 +36,12 @@ public static void registerAllExtensions( "\n\021transformer.proto\"@\n\007Request\022\025\n\rserial" + "izeType\030\001 \001(\t\022\014\n\004args\030\002 \003(\014\022\020\n\010argTypes\030" + "\003 \003(\t\"=\n\010Response\022\025\n\rserializeType\030\001 \001(\t" + - "\022\014\n\004data\030\002 \001(\014\022\014\n\004type\030\003 \001(\t22\n\016GenericS" + - "ervice\022 \n\007generic\022\010.Request\032\t.Response\"\000" + - "B\030\n\006tripleB\014GenericProtoP\001b\006proto3" + "\022\014\n\004data\030\002 \001(\014\022\014\n\004type\030\003 \001(\t2\220\001\n\016Generic" + + "Service\022 \n\007generic\022\010.Request\032\t.Response\"" + + "\000\022,\n\017genericBiStream\022\010.Request\032\t.Respons" + + "e\"\000(\0010\001\022.\n\023genericServerStream\022\010.Request" + + "\032\t.Response\"\0000\001B\030\n\006tripleB\014GenericProtoP" + + "\001b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, diff --git a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericServiceGrpc.java b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericServiceGrpc.java index 75a2e5ba5..27d1c02c6 100644 --- a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericServiceGrpc.java +++ b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/GenericServiceGrpc.java @@ -46,6 +46,68 @@ triple.Response> getGenericMethod() { return getGenericMethod; } + private static volatile io.grpc.MethodDescriptor getGenericBiStreamMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "genericBiStream", + requestType = triple.Request.class, + responseType = triple.Response.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getGenericBiStreamMethod() { + io.grpc.MethodDescriptor getGenericBiStreamMethod; + if ((getGenericBiStreamMethod = GenericServiceGrpc.getGenericBiStreamMethod) == null) { + synchronized (GenericServiceGrpc.class) { + if ((getGenericBiStreamMethod = GenericServiceGrpc.getGenericBiStreamMethod) == null) { + GenericServiceGrpc.getGenericBiStreamMethod = getGenericBiStreamMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "genericBiStream")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + triple.Request.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + triple.Response.getDefaultInstance())) + .setSchemaDescriptor(new GenericServiceMethodDescriptorSupplier("genericBiStream")) + .build(); + } + } + } + return getGenericBiStreamMethod; + } + + private static volatile io.grpc.MethodDescriptor getGenericServerStreamMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "genericServerStream", + requestType = triple.Request.class, + responseType = triple.Response.class, + methodType = io.grpc.MethodDescriptor.MethodType.SERVER_STREAMING) + public static io.grpc.MethodDescriptor getGenericServerStreamMethod() { + io.grpc.MethodDescriptor getGenericServerStreamMethod; + if ((getGenericServerStreamMethod = GenericServiceGrpc.getGenericServerStreamMethod) == null) { + synchronized (GenericServiceGrpc.class) { + if ((getGenericServerStreamMethod = GenericServiceGrpc.getGenericServerStreamMethod) == null) { + GenericServiceGrpc.getGenericServerStreamMethod = getGenericServerStreamMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.SERVER_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "genericServerStream")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + triple.Request.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + triple.Response.getDefaultInstance())) + .setSchemaDescriptor(new GenericServiceMethodDescriptorSupplier("genericServerStream")) + .build(); + } + } + } + return getGenericServerStreamMethod; + } + /** * Creates a new async stub that supports all call types for the service */ @@ -101,6 +163,20 @@ public void generic(triple.Request request, io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getGenericMethod(), responseObserver); } + /** + */ + public io.grpc.stub.StreamObserver genericBiStream( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getGenericBiStreamMethod(), responseObserver); + } + + /** + */ + public void genericServerStream(triple.Request request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getGenericServerStreamMethod(), responseObserver); + } + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) .addMethod( @@ -110,6 +186,20 @@ public void generic(triple.Request request, triple.Request, triple.Response>( this, METHODID_GENERIC))) + .addMethod( + getGenericBiStreamMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + triple.Request, + triple.Response>( + this, METHODID_GENERIC_BI_STREAM))) + .addMethod( + getGenericServerStreamMethod(), + io.grpc.stub.ServerCalls.asyncServerStreamingCall( + new MethodHandlers< + triple.Request, + triple.Response>( + this, METHODID_GENERIC_SERVER_STREAM))) .build(); } } @@ -135,6 +225,22 @@ public void generic(triple.Request request, io.grpc.stub.ClientCalls.asyncUnaryCall( getChannel().newCall(getGenericMethod(), getCallOptions()), request, responseObserver); } + + /** + */ + public io.grpc.stub.StreamObserver genericBiStream( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getGenericBiStreamMethod(), getCallOptions()), responseObserver); + } + + /** + */ + public void genericServerStream(triple.Request request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncServerStreamingCall( + getChannel().newCall(getGenericServerStreamMethod(), getCallOptions()), request, responseObserver); + } } /** @@ -157,6 +263,14 @@ public triple.Response generic(triple.Request request) { return io.grpc.stub.ClientCalls.blockingUnaryCall( getChannel(), getGenericMethod(), getCallOptions(), request); } + + /** + */ + public java.util.Iterator genericServerStream( + triple.Request request) { + return io.grpc.stub.ClientCalls.blockingServerStreamingCall( + getChannel(), getGenericServerStreamMethod(), getCallOptions(), request); + } } /** @@ -183,6 +297,8 @@ public com.google.common.util.concurrent.ListenableFuture gener } private static final int METHODID_GENERIC = 0; + private static final int METHODID_GENERIC_SERVER_STREAM = 1; + private static final int METHODID_GENERIC_BI_STREAM = 2; private static final class MethodHandlers implements io.grpc.stub.ServerCalls.UnaryMethod, @@ -205,6 +321,10 @@ public void invoke(Req request, io.grpc.stub.StreamObserver responseObserv serviceImpl.generic((triple.Request) request, (io.grpc.stub.StreamObserver) responseObserver); break; + case METHODID_GENERIC_SERVER_STREAM: + serviceImpl.genericServerStream((triple.Request) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; default: throw new AssertionError(); } @@ -215,6 +335,9 @@ public void invoke(Req request, io.grpc.stub.StreamObserver responseObserv public io.grpc.stub.StreamObserver invoke( io.grpc.stub.StreamObserver responseObserver) { switch (methodId) { + case METHODID_GENERIC_BI_STREAM: + return (io.grpc.stub.StreamObserver) serviceImpl.genericBiStream( + (io.grpc.stub.StreamObserver) responseObserver); default: throw new AssertionError(); } @@ -267,6 +390,8 @@ public static io.grpc.ServiceDescriptor getServiceDescriptor() { serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) .setSchemaDescriptor(new GenericServiceFileDescriptorSupplier()) .addMethod(getGenericMethod()) + .addMethod(getGenericBiStreamMethod()) + .addMethod(getGenericServerStreamMethod()) .build(); } } diff --git a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/SofaGenericServiceTriple.java b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/SofaGenericServiceTriple.java index 0eb60e668..94f2b52bc 100644 --- a/remoting/remoting-triple/build/generated/source/proto/main/java/triple/SofaGenericServiceTriple.java +++ b/remoting/remoting-triple/build/generated/source/proto/main/java/triple/SofaGenericServiceTriple.java @@ -50,6 +50,23 @@ public void generic(triple.Request request, io.grpc.stub.StreamObserver genericServerStream(triple.Request request) { + return blockingStub + .withDeadlineAfter(timeout, TimeUnit.MILLISECONDS) + .genericServerStream(request); + } + + public void genericServerStream(triple.Request request, io.grpc.stub.StreamObserver responseObserver) { + stub + .withDeadlineAfter(timeout, TimeUnit.MILLISECONDS) + .genericServerStream(request, responseObserver); + } + + public io.grpc.stub.StreamObserver genericBiStream(io.grpc.stub.StreamObserver responseObserver) { + return stub + .withDeadlineAfter(timeout, TimeUnit.MILLISECONDS) + .genericBiStream(responseObserver); + } } public static SofaGenericServiceStub getSofaStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions,int timeout) { @@ -71,6 +88,14 @@ default public com.google.common.util.concurrent.ListenableFuture responseObserver); + default public java.util.Iterator genericServerStream(triple.Request request) { + throw new UnsupportedOperationException("No need to override this method, extend XxxImplBase and override all methods it allows."); + } + + public void genericServerStream(triple.Request request, io.grpc.stub.StreamObserver responseObserver); + + public io.grpc.stub.StreamObserver genericBiStream(io.grpc.stub.StreamObserver responseObserver); + } public static abstract class GenericServiceImplBase implements io.grpc.BindableService, IGenericService { @@ -89,12 +114,25 @@ public final triple.Response generic(triple.Request request) { @java.lang.Override public final com.google.common.util.concurrent.ListenableFuture genericAsync(triple.Request request) { throw new UnsupportedOperationException("No need to override this method, extend XxxImplBase and override all methods it allows."); + } + + @java.lang.Override + public final java.util.Iterator genericServerStream(triple.Request request) { + throw new UnsupportedOperationException("No need to override this method, extend XxxImplBase and override all methods it allows."); } public void generic(triple.Request request, io.grpc.stub.StreamObserver responseObserver) { asyncUnimplementedUnaryCall(triple.GenericServiceGrpc.getGenericMethod(), responseObserver); } + public io.grpc.stub.StreamObserver genericBiStream( + io.grpc.stub.StreamObserver responseObserver) { + return asyncUnimplementedStreamingCall(triple.GenericServiceGrpc.getGenericBiStreamMethod(), responseObserver); + } + public void genericServerStream(triple.Request request, + io.grpc.stub.StreamObserver responseObserver) { + asyncUnimplementedUnaryCall(triple.GenericServiceGrpc.getGenericServerStreamMethod(), responseObserver); + } @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) @@ -105,10 +143,26 @@ public void generic(triple.Request request, triple.Request, triple.Response>( proxiedImpl, METHODID_GENERIC))) + .addMethod( + triple.GenericServiceGrpc.getGenericBiStreamMethod(), + asyncBidiStreamingCall( + new MethodHandlers< + triple.Request, + triple.Response>( + proxiedImpl, METHODID_GENERIC_BI_STREAM))) + .addMethod( + triple.GenericServiceGrpc.getGenericServerStreamMethod(), + asyncServerStreamingCall( + new MethodHandlers< + triple.Request, + triple.Response>( + proxiedImpl, METHODID_GENERIC_SERVER_STREAM))) .build(); } } private static final int METHODID_GENERIC = 0; + private static final int METHODID_GENERIC_BI_STREAM = 1; + private static final int METHODID_GENERIC_SERVER_STREAM = 2; private static final class MethodHandlers implements @@ -137,6 +191,10 @@ public void invoke(Req request, io.grpc.stub.StreamObserver serviceImpl.generic((triple.Request) request, (io.grpc.stub.StreamObserver) responseObserver); break; + case METHODID_GENERIC_SERVER_STREAM: + serviceImpl.genericServerStream((triple.Request) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; default: throw new java.lang.AssertionError(); } @@ -148,6 +206,10 @@ public void invoke(Req request, io.grpc.stub.StreamObserver invoke(io.grpc.stub.StreamObserver responseObserver) { switch (methodId) { + case METHODID_GENERIC_BI_STREAM: + return (io.grpc.stub.StreamObserver + ) serviceImpl.genericBiStream( + (io.grpc.stub.StreamObserver) responseObserver); default: throw new java.lang.AssertionError(); } diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ClientStreamObserverAdapter.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ClientStreamObserverAdapter.java new file mode 100644 index 000000000..be61a5060 --- /dev/null +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ClientStreamObserverAdapter.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.message.triple.stream; + +import com.alipay.sofa.rpc.codec.Serializer; +import com.alipay.sofa.rpc.codec.SerializerFactory; +import com.alipay.sofa.rpc.core.exception.RpcErrorType; +import com.alipay.sofa.rpc.core.exception.SofaRpcException; +import com.alipay.sofa.rpc.transport.ByteArrayWrapperByteBuf; +import com.alipay.sofa.rpc.transport.StreamHandler; +import io.grpc.stub.StreamObserver; + +/** + * ClientStreamObserverAdapter. + */ +public class ClientStreamObserverAdapter implements StreamObserver { + + private final StreamHandler streamHandler; + + private final Serializer serializer; + + private volatile Class returnType; + + public ClientStreamObserverAdapter(StreamHandler streamHandler, byte serializeType) { + this.streamHandler = streamHandler; + this.serializer = SerializerFactory.getSerializer(serializeType); + } + + @Override + public void onNext(triple.Response response) { + byte[] responseDate = response.getData().toByteArray(); + Object appResponse = null; + String returnTypeName = response.getType(); + if (responseDate != null && responseDate.length > 0) { + if (returnType == null && !returnTypeName.isEmpty()) { + try { + returnType = Class.forName(returnTypeName); + } catch (ClassNotFoundException e) { + throw new SofaRpcException(RpcErrorType.CLIENT_SERIALIZE, "Can not find return type :" + returnType); + } + } + appResponse = serializer.decode(new ByteArrayWrapperByteBuf(responseDate), returnType, null); + } + + streamHandler.onMessage(appResponse); + } + + @Override + public void onError(Throwable t) { + streamHandler.onException(t); + } + + @Override + public void onCompleted() { + streamHandler.onFinish(); + } +} diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ResponseSerializeStreamHandler.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ResponseSerializeStreamHandler.java new file mode 100644 index 000000000..51edcedf7 --- /dev/null +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/message/triple/stream/ResponseSerializeStreamHandler.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.message.triple.stream; + +import com.alipay.sofa.rpc.codec.Serializer; +import com.alipay.sofa.rpc.codec.SerializerFactory; +import com.alipay.sofa.rpc.transport.StreamHandler; +import com.alipay.sofa.rpc.utils.TripleExceptionUtils; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import triple.Response; + +/** + * Response serialize stream handler. + */ +public class ResponseSerializeStreamHandler implements StreamHandler { + + private final StreamObserver streamObserver; + + private Serializer serializer; + + private String serializeType; + + public ResponseSerializeStreamHandler(StreamObserver streamObserver, String serializeType) { + this.streamObserver = streamObserver; + serializer = SerializerFactory.getSerializer(serializeType); + this.serializeType = serializeType; + } + + @Override + public void onMessage(T message) { + Response.Builder builder = Response.newBuilder(); + builder.setType(message.getClass().getName()); + builder.setSerializeType(serializeType); + builder.setData(ByteString.copyFrom(serializer.encode(message, null).array())); + + streamObserver.onNext(builder.build()); + } + + @Override + public void onFinish() { + streamObserver.onCompleted(); + } + + @Override + public void onException(Throwable throwable) { + streamObserver.onError(TripleExceptionUtils.asStatusRuntimeException(throwable)); + } + +} diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/GenericServiceImpl.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/GenericServiceImpl.java index dff254195..722c0796b 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/GenericServiceImpl.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/GenericServiceImpl.java @@ -18,8 +18,12 @@ import com.alipay.sofa.rpc.codec.Serializer; import com.alipay.sofa.rpc.codec.SerializerFactory; +import com.alipay.sofa.rpc.common.RpcConstants; +import com.alipay.sofa.rpc.common.cache.ReflectCache; import com.alipay.sofa.rpc.common.utils.ClassTypeUtils; import com.alipay.sofa.rpc.common.utils.ClassUtils; +import com.alipay.sofa.rpc.common.utils.StringUtils; +import com.alipay.sofa.rpc.config.ProviderConfig; import com.alipay.sofa.rpc.core.exception.RpcErrorType; import com.alipay.sofa.rpc.core.exception.SofaRpcException; import com.alipay.sofa.rpc.core.exception.SofaRpcRuntimeException; @@ -27,8 +31,10 @@ import com.alipay.sofa.rpc.core.response.SofaResponse; import com.alipay.sofa.rpc.log.Logger; import com.alipay.sofa.rpc.log.LoggerFactory; +import com.alipay.sofa.rpc.message.triple.stream.ResponseSerializeStreamHandler; import com.alipay.sofa.rpc.tracer.sofatracer.TracingContextKey; import com.alipay.sofa.rpc.transport.ByteArrayWrapperByteBuf; +import com.alipay.sofa.rpc.transport.StreamHandler; import com.google.protobuf.ByteString; import com.google.protobuf.ProtocolStringList; import io.grpc.Context; @@ -40,6 +46,8 @@ import java.lang.reflect.Method; import java.util.List; +import static com.alipay.sofa.rpc.common.RpcOptions.DEFAULT_SERIALIZATION; + /** * @author zhaowang * @version : GenericServiceImpl.java, v 0.1 2020年05月27日 9:19 下午 zhaowang Exp $ @@ -50,9 +58,12 @@ public class GenericServiceImpl extends SofaGenericServiceTriple.GenericServiceI protected UniqueIdInvoker invoker; - public GenericServiceImpl(UniqueIdInvoker invoker) { + private ProviderConfig providerConfig; + + public GenericServiceImpl(UniqueIdInvoker invoker, ProviderConfig serverConfig) { super(); this.invoker = invoker; + this.providerConfig = serverConfig; } @Override @@ -63,22 +74,14 @@ public void generic(Request request, StreamObserver responseObserver) SofaRequest sofaRequest = TracingContextKey.getKeySofaRequest().get(Context.current()); String methodName = sofaRequest.getMethodName(); try { - ClassLoader serviceClassLoader = invoker.getServiceClassLoader(sofaRequest); - Thread.currentThread().setContextClassLoader(serviceClassLoader); - - Method declaredMethod = invoker.getDeclaredMethod(sofaRequest, request); + Method declaredMethod = setClassLoaderAndGetRequestMethod(sofaRequest, request, + RpcConstants.INVOKER_TYPE_UNARY); if (declaredMethod == null) { throw new SofaRpcException(RpcErrorType.SERVER_NOT_FOUND_INVOKER, "Cannot find invoke method " + methodName); } - Class[] argTypes = getArgTypes(request); Serializer serializer = SerializerFactory.getSerializer(request.getSerializeType()); - Object[] invokeArgs = getInvokeArgs(request, argTypes, serializer); - - // fill sofaRequest - sofaRequest.setMethod(declaredMethod); - sofaRequest.setMethodArgs(invokeArgs); - sofaRequest.setMethodArgSigs(ClassTypeUtils.getTypeStrs(argTypes, true)); + setUnaryOrServerRequestParams(sofaRequest, request, methodName, serializer, declaredMethod, false); SofaResponse response = invoker.invoke(sofaRequest); Object ret = getAppResponse(declaredMethod, response); @@ -98,6 +101,156 @@ public void generic(Request request, StreamObserver responseObserver) } } + @Override + public StreamObserver genericBiStream(StreamObserver responseObserver) { + Method serviceMethod = getBidirectionalStreamRequestMethod(); + //通过上下文创建请求 + SofaRequest sofaRequest = TracingContextKey.getKeySofaRequest().get(Context.current()); + + if (serviceMethod == null) { + throw new SofaRpcException(RpcErrorType.SERVER_NOT_FOUND_INVOKER, "Cannot find invoke method " + + sofaRequest.getMethodName()); + } + String methodName = serviceMethod.getName(); + try { + ResponseSerializeStreamHandler serverResponseHandler = new ResponseSerializeStreamHandler(responseObserver, + getSerialization()); + + setBidirectionalStreamRequestParams(sofaRequest, serviceMethod, serverResponseHandler); + + SofaResponse sofaResponse = invoker.invoke(sofaRequest); + + StreamHandler clientHandler = (StreamHandler) sofaResponse.getAppResponse(); + + return new StreamObserver() { + volatile Serializer serializer = null; + + volatile Class[] argTypes = null; + + @Override + public void onNext(Request request) { + checkInitialize(request); + Object message = getInvokeArgs(request, argTypes, serializer, false)[0]; + clientHandler.onMessage(message); + } + + private void checkInitialize(Request request) { + if (serializer == null && argTypes == null) { + synchronized (this) { + if (serializer == null && argTypes == null) { + serializer = SerializerFactory.getSerializer(request.getSerializeType()); + argTypes = getArgTypes(request, false); + } + } + } + } + + @Override + public void onError(Throwable t) { + clientHandler.onException(t); + } + + @Override + public void onCompleted() { + clientHandler.onFinish(); + } + }; + } catch (Exception e) { + LOGGER.error("Invoke " + methodName + " error:", e); + throw new SofaRpcRuntimeException(e); + } finally { + Thread.currentThread().setContextClassLoader(Thread.currentThread().getContextClassLoader()); + } + } + + @Override + public void genericServerStream(Request request, StreamObserver responseObserver) { + SofaRequest sofaRequest = TracingContextKey.getKeySofaRequest().get(Context.current()); + Method serviceMethod = setClassLoaderAndGetRequestMethod(sofaRequest, request,RpcConstants.INVOKER_TYPE_SERVER_STREAMING); + + if (serviceMethod == null) { + throw new SofaRpcException(RpcErrorType.SERVER_NOT_FOUND_INVOKER, "Cannot find invoke method " + + sofaRequest.getMethodName()); + } + + String methodName = serviceMethod.getName(); + try { + Serializer serializer = SerializerFactory.getSerializer(request.getSerializeType()); + + setUnaryOrServerRequestParams(sofaRequest, request, methodName, serializer, serviceMethod, true); + sofaRequest.getMethodArgs()[0] = new ResponseSerializeStreamHandler<>(responseObserver, getSerialization()); + + invoker.invoke(sofaRequest); + } catch (Exception e) { + LOGGER.error("Invoke " + methodName + " error:", e); + throw new SofaRpcRuntimeException(e); + } finally { + Thread.currentThread().setContextClassLoader(Thread.currentThread().getContextClassLoader()); + } + } + + private Method setClassLoaderAndGetRequestMethod(SofaRequest sofaRequest, Request request, String callType) { + ClassLoader serviceClassLoader = invoker.getServiceClassLoader(sofaRequest); + Thread.currentThread().setContextClassLoader(serviceClassLoader); + return invoker.getDeclaredMethod(sofaRequest, request, callType); + } + + private Method getBidirectionalStreamRequestMethod() { + SofaRequest sofaRequest = TracingContextKey.getKeySofaRequest().get(Context.current()); + String uniqueName = invoker.getServiceUniqueName(sofaRequest); + return ReflectCache.getOverloadMethodCache(uniqueName, sofaRequest.getMethodName(), + new String[] { StreamHandler.class.getCanonicalName() }); + } + + /** + * Resolve method invoke args into request for unary or server-streaming calls. + * + * @param sofaRequest SofaRequest + * @param request Request + * @param methodName MethodName + * @param serializer Serializer + * @param declaredMethod Target invoke method + */ + private void setUnaryOrServerRequestParams(SofaRequest sofaRequest, Request request, String methodName, + Serializer serializer, Method declaredMethod, boolean isServerStreamCall) { + setClassLoader(sofaRequest); + if (declaredMethod == null) { + throw new SofaRpcException(RpcErrorType.SERVER_NOT_FOUND_INVOKER, "Cannot find invoke method " + + methodName); + } + Class[] argTypes = getArgTypes(request, isServerStreamCall); + Object[] invokeArgs = getInvokeArgs(request, argTypes, serializer, isServerStreamCall); + + // fill sofaRequest + sofaRequest.setMethod(declaredMethod); + sofaRequest.setMethodArgs(invokeArgs); + sofaRequest.setMethodArgSigs(ClassTypeUtils.getTypeStrs(argTypes, true)); + } + + /** + * Resolve method invoke args into request for bidirectional stream calls. + * + * @param sofaRequest SofaRequest + * @param serviceMethod Target service method + * @param serverStreamPushHandler The StreamHandler used to push a message to a client. It's a wrapper for {@link StreamObserver}, and encode method return value to {@link Response}. + */ + private void setBidirectionalStreamRequestParams(SofaRequest sofaRequest, Method serviceMethod, + StreamHandler serverStreamPushHandler) { + setClassLoader(sofaRequest); + + Class[] argTypes = new Class[] { StreamHandler.class }; + Object[] invokeArgs = new Object[] { serverStreamPushHandler }; + + sofaRequest.setMethod(serviceMethod); + sofaRequest.setMethodArgs(invokeArgs); + sofaRequest.setMethodArgSigs(ClassTypeUtils.getTypeStrs(argTypes, true)); + } + + private void setClassLoader(SofaRequest sofaRequest) { + ClassLoader serviceClassLoader = invoker.getServiceClassLoader(sofaRequest); + Thread.currentThread().setContextClassLoader(serviceClassLoader); + } + private Object getAppResponse(Method method, SofaResponse response) { if (response.isError()) { throw new SofaRpcException(RpcErrorType.SERVER_UNDECLARED_ERROR, response.getErrorMsg()); @@ -113,26 +266,71 @@ private Object getAppResponse(Method method, SofaResponse response) { return ret; } - private Class[] getArgTypes(Request request) { + /** + * Get argument types from request. + * @param request original request + * @param addStreamHandler Whether add StreamHandler as the first method param. + *

+ * For server stream call, the StreamHandler won't be transported. + * To make the argument list conform to the method definition, we need to add it as first method param manually. + * + * @return param types of target method + */ + private Class[] getArgTypes(Request request, boolean addStreamHandler) { ProtocolStringList argTypesList = request.getArgTypesList(); - int size = argTypesList.size(); + + int size = addStreamHandler ? argTypesList.size() + 1 : argTypesList.size(); Class[] argTypes = new Class[size]; - for (int i = 0; i < size; i++) { - String typeName = argTypesList.get(i); + + if (addStreamHandler) { + argTypes[0] = StreamHandler.class; + } + for (int i = addStreamHandler ? 1 : 0; i < size; i++) { + String typeName; + if (addStreamHandler) { + typeName = argTypesList.get(i - 1); + } else { + typeName = argTypesList.get(i); + } argTypes[i] = ClassTypeUtils.getClass(typeName); } return argTypes; } - private Object[] getInvokeArgs(Request request, Class[] argTypes, Serializer serializer) { + /** + * Get arguments from request. + * @param addStreamHandler if addStreamHandler == true, the first arg will be left blank and set later. + * + * @return params of target method. + */ + private Object[] getInvokeArgs(Request request, Class[] argTypes, Serializer serializer, boolean addStreamHandler) { List argsList = request.getArgsList(); - Object[] args = new Object[argsList.size()]; + int size = addStreamHandler ? argsList.size() + 1 : argsList.size(); + int start = addStreamHandler ? 1 : 0; + Object[] args = new Object[size]; - for (int i = 0; i < argsList.size(); i++) { - byte[] data = argsList.get(i).toByteArray(); + for (int i = start; i < size; i++) { + byte[] data; + if (addStreamHandler) { + data = argsList.get(i - 1).toByteArray(); + } else { + data = argsList.get(i).toByteArray(); + } args[i] = serializer.decode(new ByteArrayWrapperByteBuf(data), argTypes[i], null); } return args; } + + private String getSerialization() { + String serialization = providerConfig.getSerialization(); + if (StringUtils.isBlank(serialization)) { + serialization = getDefaultSerialization(); + } + return serialization; + } + + private String getDefaultSerialization() { + return DEFAULT_SERIALIZATION; + } } \ No newline at end of file diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/TripleServer.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/TripleServer.java index f3f312cde..f21fe5e39 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/TripleServer.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/TripleServer.java @@ -40,7 +40,6 @@ import com.alipay.sofa.rpc.utils.SofaProtoUtils; import io.grpc.BindableService; import io.grpc.MethodDescriptor; -import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptors; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; @@ -58,7 +57,9 @@ import triple.Response; import java.lang.reflect.Method; + import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; @@ -68,6 +69,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; + import static io.grpc.MethodDescriptor.generateFullMethodName; /** @@ -112,7 +114,7 @@ public class TripleServer implements Server { /** * The mapping relationship between service name and unique id invoker */ - protected Map invokerMap = new ConcurrentHashMap<>(); + protected Map invokerMap = new ConcurrentHashMap<>(); /** * invoker count @@ -296,7 +298,7 @@ private ServerServiceDefinition getServerServiceDefinition(ProviderConfig provid BindableService bindableService = (BindableService) providerConfig.getRef(); serviceDef = bindableService.bindService(); } else { - GenericServiceImpl genericService = new GenericServiceImpl(uniqueIdInvoker); + GenericServiceImpl genericService = new GenericServiceImpl(uniqueIdInvoker,providerConfig); genericService.setProxiedImpl(genericService); serviceDef = buildSofaServiceDef(genericService, providerConfig); } @@ -328,11 +330,9 @@ private void setBindableProxiedImpl(ProviderConfig providerConfig, Invoker invok private ServerServiceDefinition buildSofaServiceDef(GenericServiceImpl genericService, ProviderConfig providerConfig) { ServerServiceDefinition templateDefinition = genericService.bindService(); - ServerCallHandler templateHandler = (ServerCallHandler) templateDefinition - .getMethods().iterator().next().getServerCallHandler(); List> methodDescriptor = getMethodDescriptor(providerConfig); - List> methodDefs = getMethodDefinitions(templateHandler, - methodDescriptor); + List> methodDefs = createMethodDefinition(templateDefinition,methodDescriptor); + // Bind the actual service to a specific method in the generic service ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(getServiceDescriptor( templateDefinition, providerConfig, methodDescriptor)); for (ServerMethodDefinition methodDef : methodDefs) { @@ -341,15 +341,27 @@ private ServerServiceDefinition buildSofaServiceDef(GenericServiceImpl genericSe return builder.build(); } - private List> getMethodDefinitions(ServerCallHandler templateHandler,List> methodDescriptors) { - List> result = new ArrayList<>(); - for (MethodDescriptor methodDescriptor : methodDescriptors) { - ServerMethodDefinition serverMethodDefinition = ServerMethodDefinition.create(methodDescriptor, templateHandler); - result.add(serverMethodDefinition); - } - return result; + private List> createMethodDefinition(ServerServiceDefinition geneticServiceDefinition, List> serviceMethods){ + Collection> genericServiceMethods = geneticServiceDefinition.getMethods(); + List> serverMethodDefinitions = new ArrayList<>(); + //Map ture service method to certain generic service method. + for (ServerMethodDefinition genericMethods : genericServiceMethods){ + for(MethodDescriptor methodDescriptor : serviceMethods){ + + if(methodDescriptor.getType().equals(genericMethods.getMethodDescriptor().getType())){ + + ServerMethodDefinition genericMeth = (ServerMethodDefinition) genericMethods; + + serverMethodDefinitions.add( + ServerMethodDefinition.create(methodDescriptor, genericMeth.getServerCallHandler()) + ); + } + } + } + return serverMethodDefinitions; } + private ServiceDescriptor getServiceDescriptor(ServerServiceDefinition template, ProviderConfig providerConfig, List> methodDescriptors) { String serviceName = providerConfig.getInterfaceId(); @@ -365,9 +377,11 @@ private ServiceDescriptor getServiceDescriptor(ServerServiceDefinition template, private List> getMethodDescriptor(ProviderConfig providerConfig) { List> result = new ArrayList<>(); Set methodNames = SofaProtoUtils.getMethodNames(providerConfig.getInterfaceId()); + for (String name : methodNames) { + MethodDescriptor.MethodType methodType = SofaProtoUtils.mapGrpcCallType(providerConfig.getMethodCallType(name)); MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNARY) + .setType(methodType) .setFullMethodName(generateFullMethodName(providerConfig.getInterfaceId(), name)) .setSampledToLocalTracing(true) .setRequestMarshaller(ProtoUtils.marshaller( diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/UniqueIdInvoker.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/UniqueIdInvoker.java index 15b615bd3..0c0bf2bb0 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/UniqueIdInvoker.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/server/triple/UniqueIdInvoker.java @@ -16,6 +16,7 @@ */ package com.alipay.sofa.rpc.server.triple; +import com.alipay.sofa.rpc.common.RpcConstants; import com.alipay.sofa.rpc.common.cache.ReflectCache; import com.alipay.sofa.rpc.common.utils.StringUtils; import com.alipay.sofa.rpc.config.ConfigUniqueNameGenerator; @@ -27,10 +28,13 @@ import com.alipay.sofa.rpc.core.response.SofaResponse; import com.alipay.sofa.rpc.invoke.Invoker; import com.alipay.sofa.rpc.server.ProviderProxyInvoker; +import com.alipay.sofa.rpc.transport.StreamHandler; import triple.Request; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; @@ -174,14 +178,20 @@ public ClassLoader getServiceClassLoader(SofaRequest sofaRequest) { return ReflectCache.getServiceClassLoader(uniqueName); } - public Method getDeclaredMethod(SofaRequest sofaRequest, Request request) { + public Method getDeclaredMethod(SofaRequest sofaRequest, Request request, String callType) { String uniqueName = this.getServiceUniqueName(sofaRequest); - return ReflectCache.getOverloadMethodCache(uniqueName, sofaRequest.getMethodName(), request - .getArgTypesList() + List argTypesList = request.getArgTypesList(); + if(RpcConstants.INVOKER_TYPE_SERVER_STREAMING.equals(callType)){ + List a = new ArrayList<>(argTypesList.size()+1); + a.add(0, StreamHandler.class.getCanonicalName()); + a.addAll(argTypesList); + argTypesList = a; + } + return ReflectCache.getOverloadMethodCache(uniqueName, sofaRequest.getMethodName(), argTypesList .toArray(new String[0])); } - private String getServiceUniqueName(SofaRequest sofaRequest) { + public String getServiceUniqueName(SofaRequest sofaRequest) { this.readLock.lock(); try { Invoker invoker = this.findInvoker(sofaRequest.getInterfaceName(), getUniqueIdFromInvokeContext()); diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientInvoker.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientInvoker.java index 7079f3f13..c660a4f8f 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientInvoker.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/transport/triple/TripleClientInvoker.java @@ -38,14 +38,16 @@ import com.alipay.sofa.rpc.filter.FilterChain; import com.alipay.sofa.rpc.log.Logger; import com.alipay.sofa.rpc.log.LoggerFactory; +import com.alipay.sofa.rpc.message.MessageBuilder; import com.alipay.sofa.rpc.message.ResponseFuture; +import com.alipay.sofa.rpc.message.triple.stream.ClientStreamObserverAdapter; import com.alipay.sofa.rpc.message.triple.TripleResponseFuture; import com.alipay.sofa.rpc.transport.ByteArrayWrapperByteBuf; +import com.alipay.sofa.rpc.transport.StreamHandler; +import com.alipay.sofa.rpc.utils.SofaProtoUtils; +import com.alipay.sofa.rpc.utils.TripleExceptionUtils; import com.google.protobuf.ByteString; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.MethodDescriptor; -import io.grpc.Status; +import io.grpc.*; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; @@ -69,22 +71,22 @@ * @date 2018.12.15 7:06 PM */ public class TripleClientInvoker implements TripleInvoker { - private final static Logger LOGGER = LoggerFactory.getLogger(TripleClientInvoker.class); + private final static Logger LOGGER = LoggerFactory.getLogger(TripleClientInvoker.class); private final static String DEFAULT_SERIALIZATION = SERIALIZE_HESSIAN2; - protected Channel channel; + protected Channel channel; - protected ConsumerConfig consumerConfig; + protected ConsumerConfig consumerConfig; protected ProviderInfo providerInfo; - protected Method sofaStub; + protected Method sofaStub; - protected boolean useGeneric; + protected boolean useGeneric; - private Serializer serializer; - private String serialization; + private Serializer serializer; + private String serialization; private Map methodMap = new ConcurrentHashMap<>(); @@ -106,6 +108,23 @@ public TripleClientInvoker(ConsumerConfig consumerConfig, ProviderInfo providerI } } + public static Request getRequest(SofaRequest sofaRequest, String serialization, Serializer serializer, int trueParamStart) { + Request.Builder builder = Request.newBuilder(); + builder.setSerializeType(serialization); + + String[] methodArgSigs = sofaRequest.getMethodArgSigs(); + Object[] methodArgs = sofaRequest.getMethodArgs(); + + for (int i = trueParamStart; i < methodArgSigs.length; i++) { + Object arg = methodArgs[i]; + ByteString argByteString = ByteString.copyFrom(serializer.encode(arg, null).array()); + builder.addArgs(argByteString); + builder.addArgTypes(methodArgSigs[i]); + } + return builder.build(); + } + + private void cacheCommonData(ConsumerConfig consumerConfig) { String serialization = consumerConfig.getSerialization(); if (StringUtils.isBlank(serialization)) { @@ -121,35 +140,132 @@ protected String getDefaultSerialization() { @Override public SofaResponse invoke(SofaRequest sofaRequest, int timeout) - throws Exception { - if (!useGeneric) { - SofaResponse sofaResponse = new SofaResponse(); - Object stub = sofaStub.invoke(null, channel, buildCustomCallOptions(sofaRequest, timeout), - timeout); - final Method method = sofaRequest.getMethod(); - Object appResponse = method.invoke(stub, sofaRequest.getMethodArgs()[0]); - sofaResponse.setAppResponse(appResponse); - return sofaResponse; + throws Exception { + + MethodDescriptor.MethodType callType = mapCallType(sofaRequest); + + if(!useGeneric){ + return stubCall(sofaRequest,timeout); + } else if (callType.equals(MethodDescriptor.MethodType.UNARY)) { + return unaryCall(sofaRequest, timeout); } else { - MethodDescriptor methodDescriptor = getMethodDescriptor(sofaRequest); - Request request = getRequest(sofaRequest, serialization, serializer); - Response response = (Response) ClientCalls.blockingUnaryCall(channel, methodDescriptor, + return streamCall(sofaRequest, timeout, callType); + } + } + + private MethodDescriptor.MethodType mapCallType(SofaRequest sofaRequest) { + String sofaCallType = sofaRequest.getInvokeType(); + switch (sofaCallType) { + case RpcConstants.INVOKER_TYPE_BI_STREAMING: + return MethodDescriptor.MethodType.BIDI_STREAMING; + case RpcConstants.INVOKER_TYPE_CLIENT_STREAMING: + return MethodDescriptor.MethodType.CLIENT_STREAMING; + case RpcConstants.INVOKER_TYPE_SERVER_STREAMING: + return MethodDescriptor.MethodType.SERVER_STREAMING; + default: + return MethodDescriptor.MethodType.UNARY; + } + } + + private SofaResponse streamCall(SofaRequest sofaRequest, int timeout, MethodDescriptor.MethodType callType) { + switch (callType) { + case BIDI_STREAMING: + return binaryStreamCall(sofaRequest, timeout); + case CLIENT_STREAMING: + return clientStreamCall(sofaRequest, timeout); + case SERVER_STREAMING: + return serverStreamCall(sofaRequest, timeout); + default: + throw new SofaRpcException(RpcErrorType.CLIENT_CALL_TYPE, "Unknown stream call type:" + callType); + } + } + + + private SofaResponse unaryCall(SofaRequest sofaRequest, int timeout) throws Exception{ + MethodDescriptor methodDescriptor = getMethodDescriptor(sofaRequest); + Request request = getRequest(sofaRequest, serialization, serializer, 0); + Response response = (Response) ClientCalls.blockingUnaryCall(channel, methodDescriptor, buildCustomCallOptions(sofaRequest, timeout), request); - SofaResponse sofaResponse = new SofaResponse(); - byte[] responseDate = response.getData().toByteArray(); - Class returnType = sofaRequest.getMethod().getReturnType(); - if (returnType != void.class) { - if (responseDate != null && responseDate.length > 0) { - Serializer responseSerializer = SerializerFactory.getSerializer(response.getSerializeType()); - Object appResponse = responseSerializer.decode(new ByteArrayWrapperByteBuf(responseDate), returnType, null); - sofaResponse.setAppResponse(appResponse); - } + SofaResponse sofaResponse = new SofaResponse(); + byte[] responseDate = response.getData().toByteArray(); + Class returnType = sofaRequest.getMethod().getReturnType(); + if (returnType != void.class) { + if (responseDate != null && responseDate.length > 0) { + Serializer responseSerializer = SerializerFactory.getSerializer(response.getSerializeType()); + Object appResponse = responseSerializer.decode(new ByteArrayWrapperByteBuf(responseDate), returnType, null); + sofaResponse.setAppResponse(appResponse); } - - return sofaResponse; } + return sofaResponse; + } + private SofaResponse stubCall(SofaRequest sofaRequest, int timeout) throws Exception{ + SofaResponse sofaResponse = new SofaResponse(); + Object stub = sofaStub.invoke(null, channel, buildCustomCallOptions(sofaRequest, timeout), + timeout); + final Method method = sofaRequest.getMethod(); + Object appResponse = method.invoke(stub, sofaRequest.getMethodArgs()[0]); + sofaResponse.setAppResponse(appResponse); + return sofaResponse; + } + + private SofaResponse binaryStreamCall(SofaRequest sofaRequest, int timeout) { + StreamHandler streamHandler = (StreamHandler) sofaRequest.getMethodArgs()[0]; + + MethodDescriptor methodDescriptor = getMethodDescriptor(sofaRequest); + ClientCall call = channel.newCall(methodDescriptor, buildCustomCallOptions(sofaRequest, timeout)); + + StreamObserver observer = ClientCalls.asyncBidiStreamingCall( + call, + new ClientStreamObserverAdapter( + streamHandler, + sofaRequest.getSerializeType() + ) + ); + StreamHandler handler = new StreamHandler() { + @Override + public void onMessage(Object message) { + SofaRequest request = MessageBuilder.copyEmptyRequest(sofaRequest); + Object[] args = new Object[]{message}; + request.setMethodArgs(args); + request.setMethodArgSigs(rebuildTrueRequestArgSigs(args)); + Request req = getRequest(request, serialization, serializer, 0); + observer.onNext(req); + } + + @Override + public void onFinish() { + observer.onCompleted(); + } + + @Override + public void onException(Throwable throwable) { + observer.onError(TripleExceptionUtils.asStatusRuntimeException(throwable)); + } + }; + SofaResponse sofaResponse = new SofaResponse(); + sofaResponse.setAppResponse(handler); + return sofaResponse; + } + + private SofaResponse clientStreamCall(SofaRequest sofaRequest, int timeout) { + return binaryStreamCall(sofaRequest, timeout); + } + + private SofaResponse serverStreamCall(SofaRequest sofaRequest, int timeout) { + StreamHandler streamHandler = (StreamHandler) sofaRequest.getMethodArgs()[0]; + + MethodDescriptor methodDescriptor = getMethodDescriptor(sofaRequest); + ClientCall call = channel.newCall(methodDescriptor, buildCustomCallOptions(sofaRequest, timeout)); + + Request req = getRequest(sofaRequest, serialization, serializer, 1); + + ClientStreamObserverAdapter responseObserver = new ClientStreamObserverAdapter(streamHandler, sofaRequest.getSerializeType()); + + ClientCalls.asyncServerStreamingCall(call, req, responseObserver); + + return new SofaResponse(); } @Override @@ -199,7 +315,7 @@ public void onCompleted() { }); } else { MethodDescriptor methodDescriptor = getMethodDescriptor(sofaRequest); - Request request = getRequest(sofaRequest, serialization, serializer); + Request request = getRequest(sofaRequest, serialization, serializer, 0); ClientCalls.asyncUnaryCall(channel.newCall(methodDescriptor, buildCustomCallOptions(sofaRequest, timeout)), request, new StreamObserver() { @Override public void onNext(Object o) { @@ -220,6 +336,24 @@ public void onCompleted() { return future; } + /** + * Build arg sigs for stream calls. + * + * @param requestArgs request args + * @return arg sigs, arg.getClass().getName(). + */ + private String[] rebuildTrueRequestArgSigs(Object[] requestArgs) { + String[] classes = new String[requestArgs.length]; + for (int k = 0; k < requestArgs.length; k++) { + if (requestArgs[k] != null) { + classes[k] = requestArgs[k].getClass().getName(); + } else { + classes[k] = void.class.getName(); + } + } + return classes; + } + private void processSuccess(boolean needDecode, RpcInternalContext context, SofaRequest sofaRequest, Object o, SofaResponseCallback sofaResponseCallback, TripleResponseFuture future, ClassLoader classLoader) { ClassLoader oldCl = Thread.currentThread().getContextClassLoader(); try { @@ -299,7 +433,7 @@ private void processError(RpcInternalContext context, SofaRequest sofaRequest, T Status status = Status.fromThrowable(throwable); if (status.getCode() == Status.Code.UNKNOWN) { sofaResponseCallback.onAppException(throwable, sofaRequest.getMethodName(), sofaRequest); - }else { + } else { sofaResponseCallback.onSofaException(new SofaRpcException(RpcErrorType.UNKNOWN, status.getCause()), sofaRequest.getMethodName(), sofaRequest); } } else { @@ -342,41 +476,28 @@ protected void pickupBaggage(RpcInternalContext context, SofaResponse response) } } - private MethodDescriptor getMethodDescriptor(SofaRequest sofaRequest) { + private MethodDescriptor getMethodDescriptor(SofaRequest sofaRequest) { String serviceName = sofaRequest.getInterfaceName(); String methodName = sofaRequest.getMethodName(); MethodDescriptor.Marshaller requestMarshaller = ProtoUtils.marshaller(Request.getDefaultInstance()); MethodDescriptor.Marshaller responseMarshaller = ProtoUtils.marshaller(Response.getDefaultInstance()); String fullMethodName = generateFullMethodName(serviceName, methodName); - MethodDescriptor methodDescriptor = MethodDescriptor + + MethodDescriptor.Builder builder = MethodDescriptor .newBuilder() - .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName(fullMethodName) .setSampledToLocalTracing(true) .setRequestMarshaller((MethodDescriptor.Marshaller) requestMarshaller) - .setResponseMarshaller((MethodDescriptor.Marshaller) responseMarshaller) - .build(); - return methodDescriptor; - } - - public static Request getRequest(SofaRequest sofaRequest, String serialization, Serializer serializer) { - Request.Builder builder = Request.newBuilder(); - builder.setSerializeType(serialization); + .setResponseMarshaller((MethodDescriptor.Marshaller) responseMarshaller); - String[] methodArgSigs = sofaRequest.getMethodArgSigs(); - Object[] methodArgs = sofaRequest.getMethodArgs(); - - for (int i = 0; i < methodArgSigs.length; i++) { - Object arg = methodArgs[i]; - ByteString argByteString = ByteString.copyFrom(serializer.encode(arg, null).array()); - builder.addArgs(argByteString); - builder.addArgTypes(methodArgSigs[i]); - } + MethodDescriptor.MethodType callType = SofaProtoUtils.mapGrpcCallType(sofaRequest.getInvokeType()); + builder.setType(callType); return builder.build(); } /** * set some custom info + * * @param sofaRequest * @param timeout * @return diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/SofaProtoUtils.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/SofaProtoUtils.java index eec457315..9afccbeef 100644 --- a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/SofaProtoUtils.java +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/SofaProtoUtils.java @@ -18,14 +18,20 @@ import com.alipay.sofa.rpc.common.utils.ClassUtils; import com.alipay.sofa.rpc.config.ConsumerConfig; +import com.alipay.sofa.rpc.core.exception.RpcErrorType; +import com.alipay.sofa.rpc.core.exception.SofaRpcException; + import io.grpc.BindableService; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.MethodDescriptor; import java.lang.reflect.Method; import java.util.HashSet; import java.util.Set; +import static com.alipay.sofa.rpc.common.RpcConstants.*; + /** * @author zhaowang * @version : SofaProtoUtils.java, v 0.1 2020年05月27日 7:25 下午 zhaowang Exp $ @@ -62,4 +68,22 @@ public static boolean checkIfUseGeneric(ConsumerConfig consumerConfig) { return true; } + public static MethodDescriptor.MethodType mapGrpcCallType(String callType) { + switch (callType) { + case INVOKER_TYPE_ONEWAY: + case INVOKER_TYPE_FUTURE: + case INVOKER_TYPE_CALLBACK: + case INVOKER_TYPE_SYNC: + return MethodDescriptor.MethodType.UNARY; + case INVOKER_TYPE_BI_STREAMING: + return MethodDescriptor.MethodType.BIDI_STREAMING; + case INVOKER_TYPE_CLIENT_STREAMING: + return MethodDescriptor.MethodType.CLIENT_STREAMING; + case INVOKER_TYPE_SERVER_STREAMING: + return MethodDescriptor.MethodType.SERVER_STREAMING; + default: + throw new SofaRpcException(RpcErrorType.CLIENT_CALL_TYPE, "Unsupported invoke type:" + callType); + } + } + } \ No newline at end of file diff --git a/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/TripleExceptionUtils.java b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/TripleExceptionUtils.java new file mode 100644 index 000000000..5387e3af3 --- /dev/null +++ b/remoting/remoting-triple/src/main/java/com/alipay/sofa/rpc/utils/TripleExceptionUtils.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.utils; + +import io.grpc.Status; +import io.grpc.StatusRuntimeException; + +public class TripleExceptionUtils { + + public static StatusRuntimeException asStatusRuntimeException(Throwable t) { + if (t != null) { + return Status.fromThrowable(t).withDescription(t.getMessage()).withCause(t.getCause()).asRuntimeException(); + } else { + return Status.UNKNOWN.withDescription("Error message is null.").asRuntimeException(); + } + } + +} diff --git a/remoting/remoting-triple/src/main/proto/transformer.proto b/remoting/remoting-triple/src/main/proto/transformer.proto index be08b097d..2152268db 100644 --- a/remoting/remoting-triple/src/main/proto/transformer.proto +++ b/remoting/remoting-triple/src/main/proto/transformer.proto @@ -9,6 +9,11 @@ option java_outer_classname = "GenericProto"; service GenericService { rpc generic (Request) returns (Response) {} + + rpc genericBiStream (stream Request) returns (stream Response){} + + rpc genericServerStream(Request) returns (stream Response){} + } message Request { diff --git a/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/server/triple/GenericServiceImplTest.java b/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/server/triple/GenericServiceImplTest.java index e0c566ec3..a70c2199f 100644 --- a/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/server/triple/GenericServiceImplTest.java +++ b/remoting/remoting-triple/src/test/java/com/alipay/sofa/rpc/server/triple/GenericServiceImplTest.java @@ -63,7 +63,7 @@ public GenericServiceImplTest(){ ProviderProxyInvoker invoker = new ProviderProxyInvoker(providerConfig); UniqueIdInvoker uniqueIdInvoker = new UniqueIdInvoker(); uniqueIdInvoker.registerInvoker(providerConfig, invoker); - genericService = new GenericServiceImpl(uniqueIdInvoker); + genericService = new GenericServiceImpl(uniqueIdInvoker,providerConfig); responseObserver = new MockStreamObserver<>(); } @@ -136,7 +136,7 @@ private Object getReturnValue(Method method) { private Request buildRequest(Method method, Object[] args) { Class[] parameterTypes = method.getParameterTypes(); SofaRequest sofaRequest = MessageBuilder.buildSofaRequest(HelloService.class, method, parameterTypes, args); - Request request = TripleClientInvoker.getRequest(sofaRequest, serialization, serializer); + Request request = TripleClientInvoker.getRequest(sofaRequest, serialization, serializer, 0); Context context = Context.current().withValue(TracingContextKey.getKeySofaRequest(), sofaRequest); context.attach(); return request; diff --git a/test/test-integration/src/main/proto/helloworld.proto b/test/test-integration/src/main/proto/helloworld.proto index a237b8e2f..be12afabc 100644 --- a/test/test-integration/src/main/proto/helloworld.proto +++ b/test/test-integration/src/main/proto/helloworld.proto @@ -25,6 +25,8 @@ package helloworld; service Greeter { // Sends a greeting rpc SayHello (HelloRequest) returns (HelloReply) {} + + rpc SayHelloBinary (stream HelloRequest) returns (stream HelloReply){} } // The request message containing the user's name. diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/GreeterImpl.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/GreeterImpl.java index ec3e2f334..b917c1c75 100644 --- a/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/GreeterImpl.java +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/GreeterImpl.java @@ -51,4 +51,24 @@ public void sayHello(HelloRequest req, StreamObserver responseObserv responseObserver.onNext(reply); responseObserver.onCompleted(); } + + @Override + public StreamObserver sayHelloBinary(StreamObserver responseObserver) { + return new StreamObserver() { + + @Override + public void onNext(HelloRequest value) { + responseObserver.onNext(HelloReply.newBuilder().setMessage(value.getName()) + .build()); + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onCompleted() { + } + }; + } } \ No newline at end of file diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/TripleServerTest.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/TripleServerTest.java index 3c608ca86..1693fa45a 100644 --- a/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/TripleServerTest.java +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/test/triple/TripleServerTest.java @@ -33,11 +33,15 @@ import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; import io.grpc.examples.helloworld.SofaGreeterTriple; +import io.grpc.stub.StreamObserver; import org.junit.After; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + /** * @author leizhiyuan */ @@ -208,6 +212,55 @@ public String messageSize(String msg, int responseSize) { Assert.assertEquals(reply, "Hello! world"); } + @Test + public void testBiStream() throws InterruptedException { + ApplicationConfig applicationConfig = new ApplicationConfig().setAppName("triple-server"); + + int port = 50052; + + ServerConfig serverConfig = new ServerConfig() + .setProtocol(RpcConstants.PROTOCOL_TYPE_TRIPLE) + .setPort(port); + + ProviderConfig providerConfig = new ProviderConfig() + .setApplication(applicationConfig) + .setBootstrap(RpcConstants.PROTOCOL_TYPE_TRIPLE) + .setInterfaceId(SofaGreeterTriple.IGreeter.class.getName()) + .setRef(new GreeterImpl()) + .setServer(serverConfig); + + providerConfig.export(); + + ConsumerConfig consumerConfig = new ConsumerConfig(); + consumerConfig.setInterfaceId(SofaGreeterTriple.IGreeter.class.getName()) + .setProtocol(RpcConstants.PROTOCOL_TYPE_TRIPLE) + .setDirectUrl("tri://127.0.0.1:" + port); + + SofaGreeterTriple.IGreeter greeterBlockingStub = consumerConfig.refer(); + + HelloRequest request = HelloRequest.newBuilder().setName("Hello world!").build(); + CountDownLatch countDownLatch = new CountDownLatch(1); + StreamObserver requestStreamObserver = greeterBlockingStub + .sayHelloBinary(new StreamObserver() { + @Override + public void onNext(HelloReply value) { + Assert.assertEquals(value.getMessage(), request.getName()); + countDownLatch.countDown(); + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onCompleted() { + } + }); + requestStreamObserver.onNext(request); + Assert.assertTrue(countDownLatch.await(20, TimeUnit.SECONDS)); + requestStreamObserver.onCompleted(); + } + @Test //同步调用,直连 public void testSyncTimeout() { diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ClientRequest.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ClientRequest.java new file mode 100644 index 000000000..efe29dc09 --- /dev/null +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ClientRequest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.triple.stream; + +public class ClientRequest { + private String meg; + + private int count; + + public ClientRequest(String meg, int count) { + this.meg = meg; + this.count = count; + } + + public String getMsg() { + return meg; + } + + public int getCount() { + return count; + } +} diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloService.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloService.java new file mode 100644 index 000000000..f8fa147be --- /dev/null +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloService.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.triple.stream; + +import com.alipay.sofa.rpc.transport.StreamHandler; + +public interface HelloService { + + String CMD_TRIGGER_STREAM_FINISH = "finish"; + + String CMD_TRIGGER_STEAM_ERROR = "error"; + + String ERROR_MSG = "error msg"; + + void sayHello(); + + void sayHello(String msg); + + String sayHelloUnary(String message); + + StreamHandler sayHelloBiStream(StreamHandler streamHandler); + + void sayHelloServerStream(StreamHandler streamHandler, ClientRequest clientRequest); + +} diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloServiceImpl.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloServiceImpl.java new file mode 100644 index 000000000..491281f61 --- /dev/null +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/HelloServiceImpl.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.triple.stream; + +import com.alipay.sofa.rpc.transport.StreamHandler; + +public class HelloServiceImpl implements HelloService { + + @Override + public void sayHello() { + System.out.println("Get hello from consumer!"); + } + + @Override + public void sayHello(String msg) { + System.out.println("Get " + msg + "from consumer"); + } + + @Override + public String sayHelloUnary(String message) { + System.out.println("Get hello from consumer and try response..."); + return "Hello too, " + message; + } + + @Override + public StreamHandler sayHelloBiStream(StreamHandler streamHandler) { + return new ClientRequestEchoHandler(streamHandler); + } + + @Override + public void sayHelloServerStream(StreamHandler streamHandler, ClientRequest clientRequest) { + streamHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount())); + streamHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount() + 1)); + streamHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount() + 2)); + streamHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount() + 3)); + streamHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount() + 4)); + if (clientRequest.getMsg().equals(HelloService.CMD_TRIGGER_STEAM_ERROR)) { + streamHandler.onException(new RuntimeException(HelloService.ERROR_MSG)); + } else { + streamHandler.onFinish(); + } + } + + static class ClientRequestEchoHandler implements StreamHandler { + + StreamHandler respHandler; + + public ClientRequestEchoHandler(StreamHandler respHandler) { + this.respHandler = respHandler; + } + + @Override + public void onMessage(ClientRequest clientRequest) { + if (clientRequest.getMsg().equals(CMD_TRIGGER_STREAM_FINISH)) { + respHandler.onFinish(); + } else if (clientRequest.getMsg().equals(CMD_TRIGGER_STEAM_ERROR)) { + respHandler.onException(new RuntimeException(ERROR_MSG)); + } + else { + respHandler.onMessage(new ServerResponse(clientRequest.getMsg(), clientRequest.getCount())); + } + } + + @Override + public void onFinish() { + respHandler.onFinish(); + } + + @Override + public void onException(Throwable throwable) { + respHandler.onMessage(new ServerResponse("Received exception:" + throwable.getMessage(), -2)); + respHandler.onException(throwable); + throwable.printStackTrace(); + } + }; +} diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ServerResponse.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ServerResponse.java new file mode 100644 index 000000000..53eb0d970 --- /dev/null +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/ServerResponse.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.triple.stream; + +public class ServerResponse { + private String msg; + + private int count; + + public ServerResponse(String msg, int count) { + this.msg = msg; + this.count = count; + } + + public String getMsg() { + return msg; + } + + public int getCount() { + return count; + } +} diff --git a/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/TripleGenericStreamTest.java b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/TripleGenericStreamTest.java new file mode 100644 index 000000000..1c5572403 --- /dev/null +++ b/test/test-integration/src/test/java/com/alipay/sofa/rpc/triple/stream/TripleGenericStreamTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alipay.sofa.rpc.triple.stream; + +import com.alipay.sofa.rpc.config.ConsumerConfig; +import com.alipay.sofa.rpc.config.ProviderConfig; +import com.alipay.sofa.rpc.config.ServerConfig; +import com.alipay.sofa.rpc.context.RpcInternalContext; +import com.alipay.sofa.rpc.context.RpcInvokeContext; +import com.alipay.sofa.rpc.context.RpcRunningState; +import com.alipay.sofa.rpc.context.RpcRuntimeContext; +import com.alipay.sofa.rpc.transport.StreamHandler; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class TripleGenericStreamTest { + + static final String HELLO_MSG = "Hello, world!"; + ConsumerConfig consumerConfig; + ProviderConfig providerConfig; + HelloService helloServiceInst; + + ConsumerConfig consumerRefer() { + ConsumerConfig consumerConfig = new ConsumerConfig() + .setInterfaceId(HelloService.class.getName()) + .setProtocol("tri") + .setDirectUrl("triple://127.0.0.1:12200"); + consumerConfig.refer(); + return consumerConfig; + } + + ProviderConfig providerExport() { + ServerConfig serverConfig = new ServerConfig() + .setProtocol("tri") + .setPort(12200) + .setDaemon(false); + + helloServiceInst = Mockito.spy(new HelloServiceImpl()); + + ProviderConfig providerConfig = new ProviderConfig() + .setInterfaceId(HelloService.class.getName()) + .setRef(helloServiceInst) + .setServer(serverConfig); + + providerConfig.export(); + return providerConfig; + } + + @Before + public void bootStrap() { + RpcRunningState.setUnitTestMode(true); + providerConfig = providerExport(); + consumerConfig = consumerRefer(); + } + + @After + public void shutdown() { + consumerConfig.unRefer(); + providerConfig.unExport(); + RpcRuntimeContext.destroy(); + RpcInternalContext.removeContext(); + RpcInvokeContext.removeContext(); + } + + public void testTripleBiStream(boolean endWithException) throws InterruptedException { + + int requestTimes = 5; + CountDownLatch countDownLatch = new CountDownLatch(requestTimes + 1); + + AtomicBoolean receivedFinish = new AtomicBoolean(false); + AtomicBoolean receivedException = new AtomicBoolean(false); + + HelloService helloServiceRef = consumerConfig.refer(); + + StreamHandler streamHandler = helloServiceRef + .sayHelloBiStream(new StreamHandler() { + final AtomicInteger requestCount = new AtomicInteger(0); + + @Override + public void onMessage(ServerResponse message) { + Assert.assertEquals(requestCount.getAndIncrement(), message.getCount()); + Assert.assertEquals(HELLO_MSG, message.getMsg()); + countDownLatch.countDown(); + } + + @Override + public void onFinish() { + receivedFinish.set(true); + countDownLatch.countDown(); + } + + @Override + public void onException(Throwable throwable) { + Assert.assertTrue(throwable.getMessage().contains(HelloService.ERROR_MSG)); + receivedException.set(true); + countDownLatch.countDown(); + } + }); + for (int k = 0; k < requestTimes; k++) { + streamHandler.onMessage(new ClientRequest(HELLO_MSG, k)); + } + if (!endWithException) { + streamHandler.onMessage(new ClientRequest(HelloService.CMD_TRIGGER_STREAM_FINISH, -2)); + Assert.assertTrue(countDownLatch.await(20, TimeUnit.SECONDS)); + Assert.assertTrue(receivedFinish.get()); + streamHandler.onFinish(); + Assert.assertFalse(receivedException.get()); + Assert.assertThrows(Throwable.class, () -> streamHandler.onMessage(new ClientRequest("", 123))); + } else { + streamHandler.onMessage(new ClientRequest(HelloService.CMD_TRIGGER_STEAM_ERROR, -2)); + Assert.assertTrue(countDownLatch.await(20, TimeUnit.SECONDS)); + streamHandler.onException(new RuntimeException(HelloService.ERROR_MSG)); + Assert.assertThrows(Throwable.class,()->streamHandler.onMessage(new ClientRequest(HELLO_MSG,0))); + Assert.assertFalse(receivedFinish.get()); + Assert.assertTrue(receivedException.get()); + } + verify(helloServiceInst, times(1)).sayHelloBiStream(any()); + } + + @Test + public void testTripleBiStreamException() throws InterruptedException { + testTripleBiStream(true); + } + + @Test + public void testTripleBiStreamFinish() throws InterruptedException { + testTripleBiStream(false); + } + + public void testTripleServerStream(boolean endWithException) throws InterruptedException { + HelloService helloServiceRef = consumerConfig.refer(); + AtomicInteger count = new AtomicInteger(0); + int responseTimes = 5; + CountDownLatch countDownLatch = new CountDownLatch(responseTimes + 1); + AtomicBoolean responseFinished = new AtomicBoolean(false); + AtomicBoolean responseException = new AtomicBoolean(false); + + helloServiceRef.sayHelloServerStream(new StreamHandler() { + @Override + public void onMessage(ServerResponse message) { + Assert.assertEquals(endWithException ? HelloService.CMD_TRIGGER_STEAM_ERROR : HELLO_MSG, + message.getMsg()); + Assert.assertEquals(count.getAndIncrement(), message.getCount()); + countDownLatch.countDown(); + } + + @Override + public void onFinish() { + responseFinished.set(true); + countDownLatch.countDown(); + } + + @Override + public void onException(Throwable throwable) { + Assert.assertTrue(throwable.getMessage().contains(HelloService.ERROR_MSG)); + responseException.set(true); + countDownLatch.countDown(); + } + }, new ClientRequest(endWithException ? HelloService.CMD_TRIGGER_STEAM_ERROR : HELLO_MSG, 0)); + + Assert.assertTrue(countDownLatch.await(20, TimeUnit.SECONDS)); + if (endWithException) { + Assert.assertTrue(responseException.get()); + Assert.assertFalse(responseFinished.get()); + } else { + Assert.assertTrue(responseFinished.get()); + Assert.assertFalse(responseException.get()); + } + Assert.assertEquals(responseTimes, count.get()); + verify(helloServiceInst, times(1)).sayHelloServerStream(any(), any()); + } + + @Test + public void testTripleServerStreamFinish() throws InterruptedException { + testTripleServerStream(false); + } + + @Test + public void testTripleServerStreamException() throws InterruptedException { + testTripleServerStream(true); + } +}