diff --git a/flask_assistant/core.py b/flask_assistant/core.py index 90fd806..6064b65 100644 --- a/flask_assistant/core.py +++ b/flask_assistant/core.py @@ -13,6 +13,8 @@ from api_ai.api import ApiAi from io import StringIO +from injector import Injector + def find_assistant(): # Taken from Flask-ask courtesy of @voutilad """ @@ -80,6 +82,7 @@ def __init__( dev_token=None, client_token=None, client_id=None, + injector=None, ): self.app = app @@ -98,6 +101,11 @@ def __init__( self._context_funcs = {} self._func_contexts = {} + if not injector: + injector = Injector() + + self.injector = injector + self.api = ApiAi(dev_token, client_token) if app is not None: @@ -398,8 +406,6 @@ def _set_user_profile(self): self.profile = profile_payload - - def _flask_assitant_view_func(self, nlp_result=None, *args, **kwargs): if nlp_result: # pass API query result directly @@ -445,7 +451,18 @@ def _flask_assitant_view_func(self, nlp_result=None, *args, **kwargs): return "", 400 logger.info("Matched action function: {}".format(view_func.__name__)) - result = self._map_intent_to_view_func(view_func)() + + arg_names = self._func_args(view_func) + arg_values = self._map_params_to_view_args(arg_names) + dargs = {} + for (n,v) in zip(arg_names,arg_values): + if v is not None: + dargs[n] = v + + logger.info("Injector: ") + logger.info(dargs) + result = self.injector.call_with_injection(callable=view_func ,kwargs=dargs) + if result is not None: if isinstance(result, _Response): diff --git a/requirements.txt b/requirements.txt index 97eb2d7..1f9db47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ pyasn1==0.4.7 requests==2.21.0 rsa==4.0 ruamel.yaml==0.15.81 -six==1.12.0 +six urllib3==1.24.3 -werkzeug==0.15.3 +werkzeug +Injector diff --git a/tests/test_injector.py b/tests/test_injector.py new file mode 100644 index 0000000..13070f5 --- /dev/null +++ b/tests/test_injector.py @@ -0,0 +1,92 @@ +import pytest +from flask import Flask +from flask_assistant import Assistant, ask, context_manager as manager +from injector import Module, Injector, provider, singleton, inject +from tests.helpers import build_payload, get_query_response + +class AppModule(Module): + @provider + @singleton + def provide_str(self) -> str: + return 'TEST INJECTED' + + @provider + @singleton + def provide_int(self) -> int: + return 42 + +@pytest.fixture +def assist(): + app = Flask(__name__) + module = AppModule() + injector = Injector([module]) + assist = Assistant(app, project_id="test-project-id", injector=injector) + + @assist.action("simple_intent") + def simple_intent(): + speech = "Yes" + return ask(speech) + + @inject + @assist.action("simple_intent_with_inject") + def simple_intent_with_inject(speech: str): + return ask(speech) + + @inject + @assist.action("simple_intent_with_inject_and_param") + def simple_intent_with_inject_and_param(speech: str, param): + return ask(param + "." + speech) + + @inject + @assist.action("intent_with_injects_and_2_param") + def intent_with_injects_and_2_param(speech: str, p1, p2, i: int): + return ask(speech + ":" + p1 + ":" + p2 + ":" +str(i)) + + @assist.action("add_context_1") + def add_context(): + speech = "Adding context to context_out" + manager.add("context_1") + return ask(speech) + + @assist.context("context_1") + @assist.action("intent_with_context_injects_params") + @inject + def intent_with_context_injects_params(speech: str, p1, p2, i: int): + return ask("context_1:" +speech + ":" + p1 + ":" + p2 + ":" +str(i)) + + return assist + +def test_simple_intent(assist): + client = assist.app.test_client() + payload = build_payload("simple_intent") + resp = get_query_response(client, payload) + assert "Yes" in resp["fulfillmentText"] + +def test_simple_intent_with_injection(assist): + client = assist.app.test_client() + payload = build_payload("simple_intent_with_inject") + resp = get_query_response(client, payload) + assert "TEST INJECTED" in resp["fulfillmentText"] + +def test_simple_intent_with_inject_and_param(assist): + client = assist.app.test_client() + payload = build_payload("simple_intent_with_inject_and_param",params={"param": "blue"}) + resp = get_query_response(client, payload) + assert "blue.TEST INJECTED" in resp["fulfillmentText"] + +def test_intent_with_injects_and_2_params(assist): + client = assist.app.test_client() + payload = build_payload("intent_with_injects_and_2_param",params={"p1": "blue", "p2": "female"}) + resp = get_query_response(client, payload) + assert "TEST INJECTED:blue:female:42" in resp["fulfillmentText"] + +def test_intent_with_context_injects_and_params(assist): + client = assist.app.test_client() + payload = build_payload("add_context_1") + resp = get_query_response(client, payload) + payload = build_payload("intent_with_context_injects_params", contexts=resp["outputContexts"], params={"p1": "blue", "p2": "female"}) + resp = get_query_response(client, payload) + assert "context_1:TEST INJECTED:blue:female:42" in resp["fulfillmentText"] + + +