|
import asyncio |
|
import inspect |
|
import json |
|
import logging |
|
from asyncio import Queue, CancelledError |
|
from sanic import Blueprint, response |
|
from sanic.request import Request |
|
from sanic.response import HTTPResponse |
|
from typing import Text, Dict, Any, Optional, Callable, Awaitable, NoReturn |
|
|
|
import rasa.utils.endpoints |
|
from rasa.core.channels.channel import ( |
|
InputChannel, |
|
CollectingOutputChannel, |
|
UserMessage, |
|
) |
|
|
|
|
|
# Module-level logger, named after this module so log records can be traced
# back to the channel implementation.
logger = logging.getLogger(__name__)
|
|
|
|
|
class RestInput(InputChannel):
    """A custom http input channel.

    This implementation is the basis for a custom implementation of a chat
    frontend. You can customize this to send messages to Rasa and
    retrieve responses from the assistant.
    """

    @classmethod
    def name(cls) -> Text:
        """Return the name of this channel (``"rest"``)."""
        return "rest"

    @staticmethod
    async def on_message_wrapper(
        on_new_message: Callable[[UserMessage], Awaitable[Any]],
        text: Text,
        queue: Queue,
        sender_id: Text,
        input_channel: Text,
        metadata: Optional[Dict[Text, Any]],
    ) -> None:
        """Handle one user message and stream the bot output into ``queue``.

        Builds a ``UserMessage`` whose output channel is a
        ``QueueOutputChannel`` backed by ``queue``, passes it to
        ``on_new_message`` and finally puts the sentinel string ``"DONE"``
        on the queue.

        Args:
            on_new_message: Coroutine function that processes the message.
            text: Text of the user message.
            queue: Queue that receives the bot messages and the sentinel.
            sender_id: Id of the conversation partner.
            input_channel: Name of the channel the message arrived on.
            metadata: Optional additional data attached to the message.

        Raises:
            Whatever ``on_new_message`` raises; the sentinel is queued first.
        """
        collector = QueueOutputChannel(queue)

        message = UserMessage(
            text, collector, sender_id, input_channel=input_channel, metadata=metadata
        )
        try:
            await on_new_message(message)
        finally:
            # Always emit the sentinel, even when message handling fails.
            # Otherwise the consumer loop in `stream_response` would wait on
            # `queue.get()` forever and the HTTP stream would never finish.
            await queue.put("DONE")

    async def _extract_sender(self, req: Request) -> Optional[Text]:
        """Return the ``"recipient_id"`` value of the JSON body, or ``None``."""
        return req.json.get("recipient_id", None)

    # noinspection PyMethodMayBeStatic
    def _extract_message(self, req: Request) -> Optional[Text]:
        """Return the ``"text"`` value of the JSON body, or ``None``."""
        return req.json.get("text", None)

    def _extract_input_channel(self, req: Request) -> Text:
        """Return the ``"input_channel"`` of the JSON body.

        Falls back to this channel's own name when the key is missing or
        its value is falsy.
        """
        return req.json.get("input_channel") or self.name()

    def stream_response(
        self,
        on_new_message: Callable[[UserMessage], Awaitable[None]],
        text: Text,
        sender_id: Text,
        input_channel: Text,
        metadata: Optional[Dict[Text, Any]],
    ) -> Callable[[Any], Awaitable[None]]:
        """Create the streaming function for a streamed webhook response.

        Args:
            on_new_message: Coroutine function that processes the message.
            text: Text of the user message.
            sender_id: Id of the conversation partner.
            input_channel: Name of the channel the message arrived on.
            metadata: Optional additional data attached to the message.

        Returns:
            A coroutine function taking the response object. It writes every
            bot message as one JSON document per line until the ``"DONE"``
            sentinel arrives.
        """

        async def stream(resp: Any) -> None:
            q = Queue()
            # Process the message in the background so that bot messages can
            # be written to the response as soon as they are queued.
            task = asyncio.ensure_future(
                self.on_message_wrapper(
                    on_new_message, text, q, sender_id, input_channel, metadata
                )
            )
            while True:
                result = await q.get()
                if result == "DONE":
                    break
                await resp.write(json.dumps(result) + "\n")
            # Surface any exception raised while handling the message.
            await task

        return stream

    def blueprint(
        self, on_new_message: Callable[[UserMessage], Awaitable[None]]
    ) -> Blueprint:
        """Build the Sanic blueprint exposing this channel.

        Routes:
            ``GET /``: health check, answers ``{"status": "ok"}``.
            ``POST /webhook``: receives a user message as JSON and answers
            with the bot messages, streamed if the ``stream`` flag is set.

        Args:
            on_new_message: Coroutine function that processes a message.

        Returns:
            The configured blueprint.
        """
        custom_webhook = Blueprint(
            f"custom_webhook_{type(self).__name__}",
            inspect.getmodule(self).__name__,
        )

        # noinspection PyUnusedLocal
        @custom_webhook.route("/", methods=["GET"])
        async def health(request: Request) -> HTTPResponse:
            return response.json({"status": "ok"})

        @custom_webhook.route("/webhook", methods=["POST"])
        async def receive(request: Request) -> HTTPResponse:
            sender_id = await self._extract_sender(request)
            text = self._extract_message(request)
            should_use_stream = rasa.utils.endpoints.bool_arg(
                request, "stream", default=False
            )
            input_channel = self._extract_input_channel(request)
            # `get_metadata` is not defined in this class; presumably it is
            # provided by `InputChannel` -- TODO confirm.
            metadata = self.get_metadata(request)

            if should_use_stream:
                return response.stream(
                    self.stream_response(
                        on_new_message, text, sender_id, input_channel, metadata
                    ),
                    content_type="text/event-stream",
                )

            collector = CollectingOutputChannel()
            # Best effort: failures are logged and whatever was collected up
            # to that point is still returned to the caller.
            # noinspection PyBroadException
            try:
                await on_new_message(
                    UserMessage(
                        text,
                        collector,
                        sender_id,
                        input_channel=input_channel,
                        metadata=metadata,
                    )
                )
            except CancelledError:
                logger.error(
                    f"Message handling timed out for user message '{text}'."
                )
            except Exception:
                logger.exception(
                    f"An exception occured while handling "
                    f"user message '{text}'."
                )
            return response.json(collector.messages)

        return custom_webhook
|
|
|
|
|
class QueueOutputChannel(CollectingOutputChannel):
    """Output channel that collects sent messages in a queue.

    (doesn't send them anywhere, just puts them on ``self.messages``).
    """

    @classmethod
    def name(cls) -> Text:
        """Return the name of this channel (``"queue"``)."""
        return "queue"

    # noinspection PyMissingConstructor
    def __init__(self, message_queue: Optional[Queue] = None) -> None:
        """Create the channel.

        Args:
            message_queue: Queue to put messages on. A fresh ``Queue`` is
                created when ``None`` is passed.
        """
        super().__init__()
        # Compare against None explicitly: a caller-supplied queue must be
        # kept even if the object happens to evaluate as falsy.
        self.messages = Queue() if message_queue is None else message_queue

    def latest_output(self) -> NoReturn:
        """Not supported for a queue.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("A queue doesn't allow to peek at messages.")

    async def _persist_message(self, message: Dict[Text, Any]) -> None:
        """Put ``message`` on the queue."""
        await self.messages.put(message)
Hi Dishant,
I tried your tutorial for the Rasa chatbot. Thank you for making such a simple tutorial. I am stuck on adding a custom channel.
I added this Python file to the custom_channels folder, but I am getting the following error,
and when I rename the class and update the class name in the credentials, it shows
sanic.router.RouteExists