diff --git a/Houdini/Converters.py b/Houdini/Converters.py index 965459f..124e3f9 100644 --- a/Houdini/Converters.py +++ b/Houdini/Converters.py @@ -2,6 +2,159 @@ from abc import ABC from abc import abstractmethod import asyncio +import itertools +import inspect + + +from Houdini.Cooldown import CooldownError + + +class ChecklistError(Exception): + """Raised when a checklist fails""" + + +class _ArgumentDeserializer: + __slots__ = ['name', 'components', 'callback', 'parent', 'pass_raw', 'cooldown', + 'checklist', 'instance', 'alias', 'rest_raw', 'string_delimiter', + 'string_separator', '_signature', '_arguments', '_exception_callback', + '_exception_class'] + + def __init__(self, name, callback, **kwargs): + self.callback = callback + + self.name = callback.__name__ if name is None else name + self.cooldown = kwargs.get('cooldown') + self.checklist = kwargs.get('checklist', []) + self.rest_raw = kwargs.get('rest_raw', False) + self.string_delimiter = kwargs.get('string_delimiter', []) + self.string_separator = kwargs.get('string_separator', str()) + + self.instance = None + + self._signature = list(inspect.signature(self.callback).parameters.values()) + self._arguments = inspect.getfullargspec(self.callback) + + self._exception_callback = None + self._exception_class = Exception + + if self.rest_raw: + self._signature = self._signature[:-1] + + def _can_run(self, p): + return True if not self.checklist else all(predicate(self, p) for predicate in self.checklist) + + async def _check_cooldown(self, p): + if self.cooldown is not None: + bucket = self.cooldown.get_bucket(p) + if bucket.is_cooling: + if self.cooldown.callback is not None: + if self.instance: + await self.cooldown.callback(self.instance, p) + else: + await self.cooldown.callback(p) + else: + raise CooldownError('{} invoked listener during cooldown'.format(p)) + + def _check_list(self, p): + if not self._can_run(p): + raise ChecklistError('Could not invoke listener due to checklist failure') + + def _consume_separated_string(self, ctx): + if ctx.argument[0] in self.string_delimiter: + while not ctx.argument.endswith(ctx.argument[0]): + ctx.argument += self.string_separator + next(ctx.arguments) + ctx.argument = ctx.argument[1:-1] + + def error(self, exception_class=Exception): + def decorator(exception_callback): + self._exception_callback = exception_callback + self._exception_class = exception_class + return decorator + + async def _deserialize(self, p, data): + handler_call_arguments = [self.instance, p] if self.instance is not None else [p] + handler_call_keywords = {} + + arguments = itertools.islice(data, len(data) - len(self._arguments.kwonlyargs)) + keyword_arguments = itertools.islice(data, len(data) - len(self._arguments.kwonlyargs), len(data)) + + ctx = _ConverterContext(None, arguments, None, p) + for ctx.component in itertools.islice(self._signature, len(handler_call_arguments), len(self._signature)): + if ctx.component.annotation is ctx.component.empty and ctx.component.default is not ctx.component.empty: + handler_call_arguments.append(ctx.component.default) + elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD: + ctx.argument = next(ctx.arguments) + converter = get_converter(ctx.component) + + if converter == str: + self._consume_separated_string(ctx) + + handler_call_arguments.append(await do_conversion(converter, ctx)) + elif ctx.component.kind == ctx.component.VAR_POSITIONAL: + for argument in ctx.arguments: + ctx.argument = argument + converter = get_converter(ctx.component) + + if converter == str: + self._consume_separated_string(ctx) + + handler_call_arguments.append(await do_conversion(converter, ctx)) + elif ctx.component.kind == ctx.component.KEYWORD_ONLY: + ctx.arguments = keyword_arguments + ctx.argument = next(keyword_arguments) + converter = get_converter(ctx.component) + + if converter == str: + self._consume_separated_string(ctx) + + handler_call_keywords[ctx.component.name] = await do_conversion(converter, ctx) + + if self.rest_raw: + handler_call_arguments.append(list(ctx.arguments)) + + return handler_call_arguments, handler_call_keywords + + async def __call__(self, p, data): + try: + handler_call_arguments, handler_call_keywords = await self._deserialize(p, data) + + return await self.callback(*handler_call_arguments, **handler_call_keywords) + except Exception as e: + if self._exception_callback and isinstance(e, self._exception_class): + if self.instance: + await self._exception_callback(self.instance, e) + else: + await self._exception_callback(e) + else: + raise e + + def __hash__(self): + return hash(self.__name__()) + + def __name__(self): + return "{}.{}".format(self.callback.__module__, self.callback.__name__) + + +def _listener(cls, name, **kwargs): + def decorator(callback): + if not asyncio.iscoroutinefunction(callback): + raise TypeError('All listeners must be a coroutine.') + + try: + cooldown_object = callback.__cooldown + del callback.__cooldown + except AttributeError: + cooldown_object = None + + try: + checklist = callback.__checks + del callback.__checks + except AttributeError: + checklist = [] + + listener_object = cls(name, callback, cooldown=cooldown_object, checklist=checklist, **kwargs) + return listener_object + return decorator class IConverter(ABC): @@ -212,7 +365,7 @@ def get_converter(component): async def do_conversion(converter, ctx): - if issubclass(type(converter), IConverter) and not isinstance(converter, IConverter): + if issubclass(converter, IConverter) and not isinstance(converter, IConverter): converter = converter() if isinstance(converter, IConverter): if asyncio.iscoroutinefunction(converter.convert): diff --git a/Houdini/Handlers/__init__.py b/Houdini/Handlers/__init__.py index 923e193..c8d32bc 100644 --- a/Houdini/Handlers/__init__.py +++ b/Houdini/Handlers/__init__.py @@ -1,26 +1,15 @@ import inspect import enum import os -import asyncio +import itertools from types import FunctionType -from Houdini.Converters import get_converter, do_conversion, _ConverterContext +from Houdini.Converters import _listener, _ArgumentDeserializer, get_converter, do_conversion, _ConverterContext -from Houdini.Cooldown import _Cooldown, _CooldownMapping, BucketType, CooldownError +from Houdini.Cooldown import _Cooldown, _CooldownMapping, BucketType from Houdini import Plugins -def get_relative_function_path(function_obj): - abs_function_file = inspect.getfile(function_obj) - rel_function_file = os.path.relpath(abs_function_file) - - return rel_function_file - - -class ChecklistError(Exception): - """Raised when a checklist fails""" - - class AuthorityError(Exception): """Raised when a packet is received but user has not yet authenticated""" @@ -59,159 +48,76 @@ class Priority(enum.Enum): Low = 1 -class _Listener: +class _Listener(_ArgumentDeserializer): - __slots__ = ['packet', 'components', 'handler', 'priority', - 'cooldown', 'pass_packet', 'handler_file', - 'overrides', 'pre_login', 'checklist', 'plugin'] + __slots__ = ['priority', 'packet', 'overrides'] - def __init__(self, packet, components, handler_function, **kwargs): + def __init__(self, packet, callback, **kwargs): + super().__init__(packet.id, callback, **kwargs) self.packet = packet - self.components = components - self.handler = handler_function self.priority = kwargs.get('priority', Priority.Low) self.overrides = kwargs.get('overrides', []) - self.cooldown = kwargs.get('cooldown') - self.pass_packet = kwargs.get('pass_packet', False) - self.checklist = kwargs.get('checklist', []) - - self.plugin = None if type(self.overrides) is not list: self.overrides = [self.overrides] - self.handler_file = get_relative_function_path(handler_function) - - def _can_run(self, p): - return True if not self.checklist else all(predicate(self.packet, p) for predicate in self.checklist) - - def __hash__(self): - return hash(self.__name__()) - - def __name__(self): - return "{}.{}".format(self.handler.__module__, self.handler.__name__) - - async def __call__(self, p, packet_data): - if isinstance(self.packet, XTPacket) and not self.pre_login and not p.joined_world: - await p.close() - raise AuthorityError('{} tried sending XT packet before authentication!'.format(p)) - - if self.cooldown is not None: - bucket = self.cooldown.get_bucket(p) - if bucket.is_cooling: - if self.cooldown.callback is not None: - await self.cooldown.callback(*[self.plugin, p] if self.plugin is not None else p) - else: - raise CooldownError('{} sent packet during cooldown'.format(p)) - - if not self._can_run(p): - raise ChecklistError('Could not handle packet due to checklist failure') - class _XTListener(_Listener): - __slots__ = ['pre_login', 'rest_raw', 'keywords'] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.pre_login = kwargs.get('pre_login') - self.rest_raw = kwargs.get('rest_raw', False) - - self.keywords = len(inspect.getfullargspec(self.handler).kwonlyargs) - - if self.rest_raw: - self.components = self.components[:-1] - - async def __call__(self, p, packet_data): - await super().__call__(p, packet_data) - - handler_call_arguments = [self.plugin] if self.plugin is not None else [] - handler_call_arguments += [self.packet, p] if self.pass_packet else [p] - handler_call_keywords = {} - - arguments = iter(packet_data[:-self.keywords]) - ctx = _ConverterContext(None, arguments, None, p) - for ctx.component in self.components: - if ctx.component.annotation is ctx.component.empty and ctx.component.default is not ctx.component.empty: - handler_call_arguments.append(ctx.component.default) - next(ctx.arguments) - elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD: - ctx.argument = next(ctx.arguments) - converter = get_converter(ctx.component) - - handler_call_arguments.append(await do_conversion(converter, ctx)) - elif ctx.component.kind == ctx.component.VAR_POSITIONAL: - for argument in ctx.arguments: - ctx.argument = argument - converter = get_converter(ctx.component) - - handler_call_arguments.append(await do_conversion(converter, ctx)) - elif ctx.component.kind == ctx.component.KEYWORD_ONLY: - ctx.argument = packet_data[-self.keywords:][len(handler_call_keywords)] - converter = get_converter(ctx.component) - handler_call_keywords[ctx.component.name] = await do_conversion(converter, ctx) - - if self.rest_raw: - handler_call_arguments.append(list(ctx.arguments)) - return await self.handler(*handler_call_arguments, **handler_call_keywords) - elif not len(list(ctx.arguments)): - return await self.handler(*handler_call_arguments, **handler_call_keywords) - - -class _XMLListener(_Listener): __slots__ = ['pre_login'] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.pre_login = kwargs.get('pre_login') + async def __call__(self, p, packet_data): + if not self.pre_login and not p.joined_world: + await p.close() + raise AuthorityError('{} tried sending XT packet before authentication!'.format(p)) + + await super()._check_cooldown(p) + super()._check_list(p) + await super().__call__(p, packet_data) - handler_call_arguments = [self.plugin] if self.plugin is not None else [] - handler_call_arguments += [self.packet, p] if self.pass_packet else [p] - for index, component in enumerate(self.components): - if component.default is not component.empty: - handler_call_arguments.append(component.default) - elif component.kind == component.POSITIONAL_OR_KEYWORD: - converter = get_converter(component) - ctx = _ConverterContext(component, None, packet_data, p) +class _XMLListener(_Listener): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def __call__(self, p, packet_data): + await super()._check_cooldown(p) + super()._check_list(p) + + handler_call_arguments = [self.instance, p] if self.instance is not None else [p] + + ctx = _ConverterContext(None, None, packet_data, p) + for ctx.component in itertools.islice(self._signature, len(handler_call_arguments), len(self._signature)): + if ctx.component.default is not ctx.component.empty: + handler_call_arguments.append(ctx.component.default) + elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD: + converter = get_converter(ctx.component) + handler_call_arguments.append(await do_conversion(converter, ctx)) - return await self.handler(*handler_call_arguments) + return await self.callback(*handler_call_arguments) + + +def get_relative_function_path(function_obj): + abs_function_file = inspect.getfile(function_obj) + rel_function_file = os.path.relpath(abs_function_file) + + return rel_function_file def handler(packet, **kwargs): - def decorator(handler_function): - if not asyncio.iscoroutinefunction(handler_function): - raise TypeError('All handlers must be a coroutine.') + if not issubclass(type(packet), _Packet): + raise TypeError('All handlers can only listen for either XMLPacket or XTPacket.') - components = list(inspect.signature(handler_function).parameters.values()) - components = components[2:] if str(components[0]) == 'self' else components[1:] - - if not issubclass(type(packet), _Packet): - raise TypeError('All handlers can only listen for either XMLPacket or XTPacket.') - - listener_class = _XTListener if isinstance(packet, XTPacket) else _XMLListener - - try: - cooldown_object = handler_function.__cooldown - del handler_function.__cooldown - except AttributeError: - cooldown_object = None - - try: - checklist = handler_function.__checks - del handler_function.__checks - except AttributeError: - checklist = [] - - listener_object = listener_class(packet, components, handler_function, - cooldown=cooldown_object, checklist=checklist, - **kwargs) - return listener_object - return decorator + listener_class = _XTListener if isinstance(packet, XTPacket) else _XMLListener + return _listener(listener_class, packet, **kwargs) def listener_exists(xt_listeners, xml_listeners, packet): @@ -227,7 +133,7 @@ def listeners_from_module(xt_listeners, xml_listeners, module): listener_objects = inspect.getmembers(module, is_listener) for listener_name, listener_object in listener_objects: if isinstance(module, Plugins.IPlugin): - listener_object.plugin = module + listener_object.instance = module listener_collection = xt_listeners if type(listener_object) == _XTListener else xml_listeners if listener_object.packet not in listener_collection: @@ -251,7 +157,8 @@ def remove_handlers_by_module(xt_listeners, xml_listeners, handler_module_path): def remove_handlers(remove_handler_items): for handler_id, handler_listeners in remove_handler_items: for handler_listener in handler_listeners: - if handler_listener.handler_file == handler_module_path: + handler_file = get_relative_function_path(handler_listener.callback) + if handler_file == handler_module_path: handler_listeners.remove(handler_listener) remove_handlers(xt_listeners.items()) remove_handlers(xml_listeners.items()) @@ -278,8 +185,8 @@ def check(predicate): def allow_once(): - def check_for_packet(packet, p): - return packet not in p.received_packets + def check_for_packet(listener, p): + return listener.packet not in p.received_packets return check(check_for_packet)