diff --git a/src/spaceone/core/pygrpc/client.py b/src/spaceone/core/pygrpc/client.py index 5d5343a..f0b50f4 100644 --- a/src/spaceone/core/pygrpc/client.py +++ b/src/spaceone/core/pygrpc/client.py @@ -2,10 +2,12 @@ import types import grpc from google.protobuf.json_format import ParseDict -from google.protobuf.message_factory import MessageFactory +from google.protobuf.message_factory import MessageFactory, GetMessageClass from google.protobuf.descriptor_pool import DescriptorPool from google.protobuf.descriptor import ServiceDescriptor, MethodDescriptor -from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ProtoReflectionDescriptorDatabase +from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ( + ProtoReflectionDescriptorDatabase, +) from spaceone.core.error import * _MAX_RETRIES = 2 @@ -14,24 +16,32 @@ class _ClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): - + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, options: dict, channel_key: str, request_map: dict): self._request_map = request_map self._channel_key = channel_key - self.metadata = options.get('metadata', {}) + self.metadata = options.get("metadata", {}) def _check_message(self, client_call_details, request_or_iterator, is_stream): if client_call_details.method in self._request_map: if is_stream: if not isinstance(request_or_iterator, types.GeneratorType): - raise Exception("Stream method must be specified as a generator type.") + raise Exception( + "Stream method must be specified as a generator type." + ) - return self._generate_message(request_or_iterator, client_call_details.method) + return self._generate_message( + request_or_iterator, client_call_details.method + ) else: - return self._make_message(request_or_iterator, client_call_details.method) + return self._make_message( + request_or_iterator, client_call_details.method + ) return request_or_iterator @@ -50,17 +60,17 @@ def _check_error(self, response): if isinstance(response, Exception): details = response.details() status_code = response.code().name - if details.startswith('ERROR_'): - details_split = details.split(':', 1) + if details.startswith("ERROR_"): + details_split = details.split(":", 1) if len(details_split) == 2: error_code, error_message = details_split else: error_code = details_split[0] error_message = details - if status_code == 'PERMISSION_DENIED': + if status_code == "PERMISSION_DENIED": raise ERROR_PERMISSION_DENIED() - elif status_code == 'UNAUTHENTICATED': + elif status_code == "UNAUTHENTICATED": raise ERROR_AUTHENTICATE_FAILURE(message=error_message) else: e = ERROR_INTERNAL_API(message=error_message) @@ -70,13 +80,15 @@ def _check_error(self, response): else: error_message = response.details() - if status_code == 'PERMISSION_DENIED': + if status_code == "PERMISSION_DENIED": raise ERROR_PERMISSION_DENIED() - elif status_code == 'PERMISSION_DENIED': + elif status_code == "PERMISSION_DENIED": raise ERROR_AUTHENTICATE_FAILURE(message=error_message) - elif status_code == 'UNAVAILABLE': - e = ERROR_GRPC_CONNECTION(channel=self._channel_key, message=error_message) - e.meta['channel'] = self._channel_key + elif status_code == "UNAVAILABLE": + e = ERROR_GRPC_CONNECTION( + channel=self._channel_key, message=error_message + ) + e.meta["channel"] = self._channel_key raise e else: e = ERROR_INTERNAL_API(message=error_message) @@ -92,12 +104,16 @@ def _generate_response(self, response_iterator): except Exception as e: self._check_error(e) - def _retry_call(self, continuation, client_call_details, request_or_iterator, is_stream): + def _retry_call( + self, continuation, client_call_details, request_or_iterator, is_stream + ): retries = 0 while True: try: - response_or_iterator = continuation(client_call_details, request_or_iterator) + response_or_iterator = continuation( + client_call_details, request_or_iterator + ) if is_stream: response_or_iterator = self._generate_response(response_or_iterator) @@ -107,84 +123,142 @@ def _retry_call(self, continuation, client_call_details, request_or_iterator, is return response_or_iterator except Exception as e: - if e.error_code == 'ERROR_GRPC_CONNECTION': + if e.error_code == "ERROR_GRPC_CONNECTION": if retries >= _MAX_RETRIES: - channel = e.meta.get('channel') + channel = e.meta.get("channel") if channel in _GRPC_CHANNEL: - _LOGGER.error(f'Disconnect gRPC Endpoint. (channel = {channel})') + _LOGGER.error( + f"Disconnect gRPC Endpoint. (channel = {channel})" + ) del _GRPC_CHANNEL[channel] raise e else: - _LOGGER.debug(f'Retry gRPC Call: reason = {e.message}, retry = {retries + 1}') + _LOGGER.debug( + f"Retry gRPC Call: reason = {e.message}, retry = {retries + 1}" + ) else: raise e retries += 1 - def _intercept_call(self, continuation, client_call_details, - request_or_iterator, is_request_stream, is_response_stream): - new_request_or_iterator = self. _check_message( - client_call_details, request_or_iterator, is_request_stream) - - return self._retry_call(continuation, client_call_details, - new_request_or_iterator, is_response_stream) + def _intercept_call( + self, + continuation, + client_call_details, + request_or_iterator, + is_request_stream, + is_response_stream, + ): + new_request_or_iterator = self._check_message( + client_call_details, request_or_iterator, is_request_stream + ) + + return self._retry_call( + continuation, + client_call_details, + new_request_or_iterator, + is_response_stream, + ) def intercept_unary_unary(self, continuation, client_call_details, request): - return self._intercept_call(continuation, client_call_details, request, False, False) + return self._intercept_call( + continuation, client_call_details, request, False, False + ) def intercept_unary_stream(self, continuation, client_call_details, request): - return self._intercept_call(continuation, client_call_details, request, False, True) + return self._intercept_call( + continuation, client_call_details, request, False, True + ) - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): - return self._intercept_call(continuation, client_call_details, request_iterator, True, False) + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator, True, False + ) - def intercept_stream_stream(self, continuation, client_call_details, request_iterator): - return self._intercept_call(continuation, client_call_details, request_iterator, True, True) + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator, True, True + ) class _GRPCStub(object): - - def __init__(self, desc_pool: DescriptorPool, service_desc: ServiceDescriptor, channel: grpc.Channel): + def __init__( + self, + desc_pool: DescriptorPool, + service_desc: ServiceDescriptor, + channel: grpc.Channel, + ): self._desc_pool = desc_pool for method_desc in service_desc.methods: self._bind_grpc_method(service_desc, method_desc, channel) - def _bind_grpc_method(self, service_desc: ServiceDescriptor, method_desc: MethodDescriptor, channel: grpc.Channel): + def _bind_grpc_method( + self, + service_desc: ServiceDescriptor, + method_desc: MethodDescriptor, + channel: grpc.Channel, + ): method_name = method_desc.name - method_key = f'/{service_desc.full_name}/{method_name}' - request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name) - request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc) - response_desc = self._desc_pool.FindMessageTypeByName(method_desc.output_type.full_name) - response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc) + method_key = f"/{service_desc.full_name}/{method_name}" + request_desc = self._desc_pool.FindMessageTypeByName( + method_desc.input_type.full_name + ) + # request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc) + request_message_desc = GetMessageClass(request_desc) + + response_desc = self._desc_pool.FindMessageTypeByName( + method_desc.output_type.full_name + ) + # response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc) + response_message_desc = GetMessageClass(response_desc) if method_desc.client_streaming and method_desc.server_streaming: - setattr(self, method_name, channel.stream_stream( - method_key, - request_serializer=request_message_desc.SerializeToString, - response_deserializer=response_message_desc.FromString - )) + setattr( + self, + method_name, + channel.stream_stream( + method_key, + request_serializer=request_message_desc.SerializeToString, + response_deserializer=response_message_desc.FromString, + ), + ) elif method_desc.client_streaming and not method_desc.server_streaming: - setattr(self, method_name, channel.stream_unary( - method_key, - request_serializer=request_message_desc.SerializeToString, - response_deserializer=response_message_desc.FromString - )) + setattr( + self, + method_name, + channel.stream_unary( + method_key, + request_serializer=request_message_desc.SerializeToString, + response_deserializer=response_message_desc.FromString, + ), + ) elif not method_desc.client_streaming and method_desc.server_streaming: - setattr(self, method_name, channel.unary_stream( - method_key, - request_serializer=request_message_desc.SerializeToString, - response_deserializer=response_message_desc.FromString - )) + setattr( + self, + method_name, + channel.unary_stream( + method_key, + request_serializer=request_message_desc.SerializeToString, + response_deserializer=response_message_desc.FromString, + ), + ) else: - setattr(self, method_name, channel.unary_unary( - method_key, - request_serializer=request_message_desc.SerializeToString, - response_deserializer=response_message_desc.FromString - )) + setattr( + self, + method_name, + channel.unary_unary( + method_key, + request_serializer=request_message_desc.SerializeToString, + response_deserializer=response_message_desc.FromString, + ), + ) class GRPCClient(object): - def __init__(self, channel, options, channel_key): self._request_map = {} self._api_resources = {} @@ -193,7 +267,9 @@ def __init__(self, channel, options, channel_key): self._desc_pool = DescriptorPool(self._reflection_db) self._init_grpc_reflection() - _client_interceptor = _ClientInterceptor(options, channel_key, self._request_map) + _client_interceptor = _ClientInterceptor( + options, channel_key, self._request_map + ) _intercept_channel = grpc.intercept_channel(channel, _client_interceptor) self._bind_grpc_stub(_intercept_channel) @@ -206,9 +282,12 @@ def _init_grpc_reflection(self): service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service) service_name = service_desc.name for method_desc in service_desc.methods: - method_key = f'/{service}/{method_desc.name}' - request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name) - self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc) + method_key = f"/{service}/{method_desc.name}" + request_desc = self._desc_pool.FindMessageTypeByName( + method_desc.input_type.full_name + ) + # self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc) + self._request_map[method_key] = GetMessageClass(request_desc) if service_desc.name not in self._api_resources: self._api_resources[service_name] = [] @@ -219,7 +298,11 @@ def _bind_grpc_stub(self, intercept_channel: grpc.Channel): for service in self._reflection_db.get_services(): service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service) - setattr(self, service_desc.name, _GRPCStub(self._desc_pool, service_desc, intercept_channel)) + setattr( + self, + service_desc.name, + _GRPCStub(self._desc_pool, service_desc, intercept_channel), + ) def _create_secure_channel(endpoint, options): @@ -245,8 +328,8 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o options = [] if max_message_length: - options.append(('grpc.max_send_message_length', max_message_length)) - options.append(('grpc.max_receive_message_length', max_message_length)) + options.append(("grpc.max_send_message_length", max_message_length)) + options.append(("grpc.max_receive_message_length", max_message_length)) if ssl_enabled: channel = _create_secure_channel(endpoint, options) @@ -256,12 +339,14 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o try: grpc.channel_ready_future(channel).result(timeout=3) except Exception as e: - raise ERROR_GRPC_CONNECTION(channel=endpoint, message='Channel is not ready.') + raise ERROR_GRPC_CONNECTION( + channel=endpoint, message="Channel is not ready." + ) try: _GRPC_CHANNEL[endpoint] = GRPCClient(channel, client_opts, endpoint) except Exception as e: - if hasattr(e, 'details'): + if hasattr(e, "details"): raise ERROR_GRPC_CONNECTION(channel=endpoint, message=e.details()) else: raise ERROR_GRPC_CONNECTION(channel=endpoint, message=str(e)) @@ -271,12 +356,16 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o def get_grpc_method(uri_info): try: - conn = client(endpoint=uri_info['endpoint'], ssl_enabled=uri_info['ssl_enabled']) - return getattr(getattr(conn, uri_info['service']), uri_info['method']) + conn = client( + endpoint=uri_info["endpoint"], ssl_enabled=uri_info["ssl_enabled"] + ) + return getattr(getattr(conn, uri_info["service"]), uri_info["method"]) except ERROR_BASE as e: raise e except Exception as e: - raise ERROR_GRPC_CONFIGURATION(endpoint=uri_info.get('endpoint'), - service=uri_info.get('service'), - method=uri_info.get('method')) + raise ERROR_GRPC_CONFIGURATION( + endpoint=uri_info.get("endpoint"), + service=uri_info.get("service"), + method=uri_info.get("method"), + )