Skip to content

refactor: refactor code for grpc compatability #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 169 additions & 80 deletions src/spaceone/core/pygrpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import types
import grpc
from google.protobuf.json_format import ParseDict
from google.protobuf.message_factory import MessageFactory
from google.protobuf.message_factory import MessageFactory, GetMessageClass
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.descriptor import ServiceDescriptor, MethodDescriptor
from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ProtoReflectionDescriptorDatabase
from grpc_reflection.v1alpha.proto_reflection_descriptor_database import (
ProtoReflectionDescriptorDatabase,
)
from spaceone.core.error import *

_MAX_RETRIES = 2
Expand All @@ -14,24 +16,32 @@


class _ClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):

grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
def __init__(self, options: dict, channel_key: str, request_map: dict):
self._request_map = request_map
self._channel_key = channel_key
self.metadata = options.get('metadata', {})
self.metadata = options.get("metadata", {})

def _check_message(self, client_call_details, request_or_iterator, is_stream):
if client_call_details.method in self._request_map:
if is_stream:
if not isinstance(request_or_iterator, types.GeneratorType):
raise Exception("Stream method must be specified as a generator type.")
raise Exception(
"Stream method must be specified as a generator type."
)

return self._generate_message(request_or_iterator, client_call_details.method)
return self._generate_message(
request_or_iterator, client_call_details.method
)

else:
return self._make_message(request_or_iterator, client_call_details.method)
return self._make_message(
request_or_iterator, client_call_details.method
)

return request_or_iterator

Expand All @@ -50,17 +60,17 @@ def _check_error(self, response):
if isinstance(response, Exception):
details = response.details()
status_code = response.code().name
if details.startswith('ERROR_'):
details_split = details.split(':', 1)
if details.startswith("ERROR_"):
details_split = details.split(":", 1)
if len(details_split) == 2:
error_code, error_message = details_split
else:
error_code = details_split[0]
error_message = details

if status_code == 'PERMISSION_DENIED':
if status_code == "PERMISSION_DENIED":
raise ERROR_PERMISSION_DENIED()
elif status_code == 'UNAUTHENTICATED':
elif status_code == "UNAUTHENTICATED":
raise ERROR_AUTHENTICATE_FAILURE(message=error_message)
else:
e = ERROR_INTERNAL_API(message=error_message)
Expand All @@ -70,13 +80,15 @@ def _check_error(self, response):

else:
error_message = response.details()
if status_code == 'PERMISSION_DENIED':
if status_code == "PERMISSION_DENIED":
raise ERROR_PERMISSION_DENIED()
elif status_code == 'PERMISSION_DENIED':
elif status_code == "PERMISSION_DENIED":
raise ERROR_AUTHENTICATE_FAILURE(message=error_message)
elif status_code == 'UNAVAILABLE':
e = ERROR_GRPC_CONNECTION(channel=self._channel_key, message=error_message)
e.meta['channel'] = self._channel_key
elif status_code == "UNAVAILABLE":
e = ERROR_GRPC_CONNECTION(
channel=self._channel_key, message=error_message
)
e.meta["channel"] = self._channel_key
raise e
else:
e = ERROR_INTERNAL_API(message=error_message)
Expand All @@ -92,12 +104,16 @@ def _generate_response(self, response_iterator):
except Exception as e:
self._check_error(e)

def _retry_call(self, continuation, client_call_details, request_or_iterator, is_stream):
def _retry_call(
self, continuation, client_call_details, request_or_iterator, is_stream
):
retries = 0

while True:
try:
response_or_iterator = continuation(client_call_details, request_or_iterator)
response_or_iterator = continuation(
client_call_details, request_or_iterator
)

if is_stream:
response_or_iterator = self._generate_response(response_or_iterator)
Expand All @@ -107,84 +123,142 @@ def _retry_call(self, continuation, client_call_details, request_or_iterator, is
return response_or_iterator

except Exception as e:
if e.error_code == 'ERROR_GRPC_CONNECTION':
if e.error_code == "ERROR_GRPC_CONNECTION":
if retries >= _MAX_RETRIES:
channel = e.meta.get('channel')
channel = e.meta.get("channel")
if channel in _GRPC_CHANNEL:
_LOGGER.error(f'Disconnect gRPC Endpoint. (channel = {channel})')
_LOGGER.error(
f"Disconnect gRPC Endpoint. (channel = {channel})"
)
del _GRPC_CHANNEL[channel]
raise e
else:
_LOGGER.debug(f'Retry gRPC Call: reason = {e.message}, retry = {retries + 1}')
_LOGGER.debug(
f"Retry gRPC Call: reason = {e.message}, retry = {retries + 1}"
)
else:
raise e

retries += 1

def _intercept_call(self, continuation, client_call_details,
request_or_iterator, is_request_stream, is_response_stream):
new_request_or_iterator = self. _check_message(
client_call_details, request_or_iterator, is_request_stream)

return self._retry_call(continuation, client_call_details,
new_request_or_iterator, is_response_stream)
def _intercept_call(
self,
continuation,
client_call_details,
request_or_iterator,
is_request_stream,
is_response_stream,
):
new_request_or_iterator = self._check_message(
client_call_details, request_or_iterator, is_request_stream
)

return self._retry_call(
continuation,
client_call_details,
new_request_or_iterator,
is_response_stream,
)

def intercept_unary_unary(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request, False, False)
return self._intercept_call(
continuation, client_call_details, request, False, False
)

def intercept_unary_stream(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request, False, True)
return self._intercept_call(
continuation, client_call_details, request, False, True
)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator, True, False)
def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator, True, False
)

def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator, True, True)
def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator, True, True
)


class _GRPCStub(object):

def __init__(self, desc_pool: DescriptorPool, service_desc: ServiceDescriptor, channel: grpc.Channel):
def __init__(
self,
desc_pool: DescriptorPool,
service_desc: ServiceDescriptor,
channel: grpc.Channel,
):
self._desc_pool = desc_pool
for method_desc in service_desc.methods:
self._bind_grpc_method(service_desc, method_desc, channel)

def _bind_grpc_method(self, service_desc: ServiceDescriptor, method_desc: MethodDescriptor, channel: grpc.Channel):
def _bind_grpc_method(
self,
service_desc: ServiceDescriptor,
method_desc: MethodDescriptor,
channel: grpc.Channel,
):
method_name = method_desc.name
method_key = f'/{service_desc.full_name}/{method_name}'
request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name)
request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc)
response_desc = self._desc_pool.FindMessageTypeByName(method_desc.output_type.full_name)
response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc)
method_key = f"/{service_desc.full_name}/{method_name}"
request_desc = self._desc_pool.FindMessageTypeByName(
method_desc.input_type.full_name
)
# request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc)
request_message_desc = GetMessageClass(request_desc)

response_desc = self._desc_pool.FindMessageTypeByName(
method_desc.output_type.full_name
)
# response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc)
response_message_desc = GetMessageClass(response_desc)

if method_desc.client_streaming and method_desc.server_streaming:
setattr(self, method_name, channel.stream_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.stream_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
elif method_desc.client_streaming and not method_desc.server_streaming:
setattr(self, method_name, channel.stream_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.stream_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
elif not method_desc.client_streaming and method_desc.server_streaming:
setattr(self, method_name, channel.unary_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.unary_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
else:
setattr(self, method_name, channel.unary_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.unary_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)


class GRPCClient(object):

def __init__(self, channel, options, channel_key):
self._request_map = {}
self._api_resources = {}
Expand All @@ -193,7 +267,9 @@ def __init__(self, channel, options, channel_key):
self._desc_pool = DescriptorPool(self._reflection_db)
self._init_grpc_reflection()

_client_interceptor = _ClientInterceptor(options, channel_key, self._request_map)
_client_interceptor = _ClientInterceptor(
options, channel_key, self._request_map
)
_intercept_channel = grpc.intercept_channel(channel, _client_interceptor)
self._bind_grpc_stub(_intercept_channel)

Expand All @@ -206,9 +282,12 @@ def _init_grpc_reflection(self):
service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service)
service_name = service_desc.name
for method_desc in service_desc.methods:
method_key = f'/{service}/{method_desc.name}'
request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name)
self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc)
method_key = f"/{service}/{method_desc.name}"
request_desc = self._desc_pool.FindMessageTypeByName(
method_desc.input_type.full_name
)
# self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc)
self._request_map[method_key] = GetMessageClass(request_desc)

if service_desc.name not in self._api_resources:
self._api_resources[service_name] = []
Expand All @@ -219,7 +298,11 @@ def _bind_grpc_stub(self, intercept_channel: grpc.Channel):
for service in self._reflection_db.get_services():
service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service)

setattr(self, service_desc.name, _GRPCStub(self._desc_pool, service_desc, intercept_channel))
setattr(
self,
service_desc.name,
_GRPCStub(self._desc_pool, service_desc, intercept_channel),
)


def _create_secure_channel(endpoint, options):
Expand All @@ -245,8 +328,8 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o
options = []

if max_message_length:
options.append(('grpc.max_send_message_length', max_message_length))
options.append(('grpc.max_receive_message_length', max_message_length))
options.append(("grpc.max_send_message_length", max_message_length))
options.append(("grpc.max_receive_message_length", max_message_length))

if ssl_enabled:
channel = _create_secure_channel(endpoint, options)
Expand All @@ -256,12 +339,14 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o
try:
grpc.channel_ready_future(channel).result(timeout=3)
except Exception as e:
raise ERROR_GRPC_CONNECTION(channel=endpoint, message='Channel is not ready.')
raise ERROR_GRPC_CONNECTION(
channel=endpoint, message="Channel is not ready."
)

try:
_GRPC_CHANNEL[endpoint] = GRPCClient(channel, client_opts, endpoint)
except Exception as e:
if hasattr(e, 'details'):
if hasattr(e, "details"):
raise ERROR_GRPC_CONNECTION(channel=endpoint, message=e.details())
else:
raise ERROR_GRPC_CONNECTION(channel=endpoint, message=str(e))
Expand All @@ -271,12 +356,16 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o

def get_grpc_method(uri_info):
try:
conn = client(endpoint=uri_info['endpoint'], ssl_enabled=uri_info['ssl_enabled'])
return getattr(getattr(conn, uri_info['service']), uri_info['method'])
conn = client(
endpoint=uri_info["endpoint"], ssl_enabled=uri_info["ssl_enabled"]
)
return getattr(getattr(conn, uri_info["service"]), uri_info["method"])

except ERROR_BASE as e:
raise e
except Exception as e:
raise ERROR_GRPC_CONFIGURATION(endpoint=uri_info.get('endpoint'),
service=uri_info.get('service'),
method=uri_info.get('method'))
raise ERROR_GRPC_CONFIGURATION(
endpoint=uri_info.get("endpoint"),
service=uri_info.get("service"),
method=uri_info.get("method"),
)
Loading