# Source code for ayx_python_sdk.providers.amp_provider.sdk_tool_service_v2
# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test harness implementation of the SDK Engine service."""
import asyncio
import logging
from ayx_python_sdk.core.constants import Anchor
from ayx_python_sdk.providers.amp_provider.amp_driver import AMPDriver
from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_v2_pb2 import (
ControlOut,
RecordTransferOut,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.sdk_tool_service_v2_pb2_grpc import (
SdkToolV2Servicer,
)
class SdkToolServiceV2(SdkToolV2Servicer):
    """
    Implementation of the SDK Engine V2 service.

    Serves two bidirectional gRPC streams: ``Control`` (plugin initialization,
    callback round-trips, completion handshake) and ``RecordTransfer``
    (incoming and outgoing record batches). All coordination state below is
    defined on the class, so it is shared by every instance.
    """

    logger = logging.getLogger()
    driver = AMPDriver()
    # NOTE(review): the asyncio primitives below are created at class-definition
    # (import) time. On Python < 3.10 they bind to the event loop current at
    # import, and ``asyncio.Future()`` looks up an event loop when constructed;
    # confirm they end up on the loop the gRPC server actually runs on.
    init_data: asyncio.Future = asyncio.Future()
    # Both teardown conditions share this lock; waiting on either releases it.
    driver_guard: asyncio.Lock = asyncio.Lock()
    # Notified by RecordTransfer after record teardown; Control waits on it.
    cond_teardown = asyncio.Condition(lock=driver_guard)
    # Notified by _driver_callback_worker when driver work is finished or has
    # failed (or immediately in update-only mode); RecordTransfer waits on it.
    record_teardown = asyncio.Condition(lock=driver_guard)
    curr_driver_fn = None
    # Passed to driver._initialize_plugin (presumably set once init finishes);
    # _record_driver_actions waits on it before dispatching record batches.
    ready_for_records: asyncio.Event = asyncio.Event()
    # Set once a record batch (or an empty incoming connection) has been seen.
    record_batch_received: asyncio.Event = asyncio.Event()
    # Strong references to fire-and-forget tasks so they are not garbage
    # collected mid-execution (see the asyncio.create_task documentation).
    _background_tasks: set = set()

    async def Control(self, request_iterator, context):  # type: ignore # noqa: N802
        """
        Handle Control messages.

        Initialize plugin on initialize ControlIn, then send any ControlOut messages
        to the client as needed, while monitoring for client responses.

        Starts reader, writer and callback-worker tasks, then waits on
        ``cond_teardown``. Once notified, it queues ``confirm_complete``, waits
        for the control output queue to drain, and flushes control I/O. The
        helper tasks are cancelled on every exit path, including a client
        disconnect, so they never outlive this stream.
        """
        tasks: list = []
        try:
            # Start read/writes and a worker to handle any callbacks for the driver
            tasks = [
                asyncio.create_task(
                    self._ctrl_read(request_iterator), name="_ctrl_read"
                ),
                asyncio.create_task(self._ctrl_write(context), name="_ctrl_write"),
                asyncio.create_task(
                    self._driver_callback_worker(), name="_driver_callback_worker"
                ),
                asyncio.create_task(
                    self._user_callback_worker(), name="_user_callback_worker"
                ),
            ]
            # Clean up tasks
            # NOTE(review): Condition.wait() has no predicate here; a
            # notification sent before this wait starts would be missed.
            async with self.cond_teardown:
                self.logger.debug("Control waiting on teardown notification")
                await self.cond_teardown.wait()
            await asyncio.sleep(0)
            complete_msg = ControlOut()
            complete_msg.confirm_complete.SetInParent()
            self.driver.ctrl_io.ctrl_out.put_nowait(complete_msg)
            # Wait until _ctrl_write has processed everything queued so far,
            # including confirm_complete.
            await self.driver.ctrl_io.ctrl_out.join()
            self.logger.debug("Control starting teardown")
            await self.driver.ctrl_io.flush()
            self.logger.debug("Control stream waiting for close")
        except asyncio.CancelledError as e:
            # Deliberately swallowed: a disconnect ends the handler normally.
            self.logger.error("ERROR: Client side disconnected from server.")
            self.logger.error(repr(e))
        finally:
            # Cancelling an already-finished task is a no-op.
            for task in tasks:
                task.cancel()

    def _track_background_task(self, task: asyncio.Task) -> None:
        """Hold a strong reference to ``task`` until it finishes."""
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Release a finished background task and log the exception it raised, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception also prevents "exception was never retrieved".
        exc = task.exception()
        if exc is not None:
            self.logger.error("%s", repr(exc))

    async def _ctrl_read(self, request_iterator) -> None:  # type: ignore
        """
        Consume ControlIn messages from the client and dispatch them by payload.

        - ``plugin_initialization_data``: run ``driver._initialize_plugin`` as a
          tracked background task.
        - ``incoming_connection_complete``: queue the driver's
          incoming-connection-complete callback for the closed anchor.
        - ``translated_message`` / ``decrypted_password`` / ``dcm_e_response``:
          hand the response to a registered callback action, or store it for a
          blocking waiter.
        - ``notify_complete``: queue the driver's ``on_complete`` callback.
        """
        awaits_response = {"translated_message", "decrypted_password", "dcm_e_response"}
        async for request in request_iterator:
            payload = request.WhichOneof("payload")
            if payload == "plugin_initialization_data":  # TODO update this to const
                try:
                    init_task = asyncio.create_task(
                        self.driver._initialize_plugin(
                            request, self.ready_for_records, self.record_batch_received
                        )
                    )
                except Exception as e:
                    self.logger.error("%s", repr(e))
                else:
                    # Without a stored reference the task may be collected
                    # mid-run; failures inside it are logged by the callback.
                    self._track_background_task(init_task)
            elif payload == "incoming_connection_complete":
                conn_info = request.incoming_connection_complete
                closed_anchor = Anchor(conn_info.anchor_name, conn_info.connection_name)
                self.logger.debug("Pushing incoming connection complete")
                if (
                    not self.record_batch_received.is_set()
                ):  # OOP sends incoming_connection_complete without sending any records
                    self.logger.debug("empty row bug detected")
                    self.record_batch_received.set()
                self.driver.ctrl_io.push_driver_callback(
                    self.driver.incoming_connection_complete_callback, closed_anchor
                )
            elif payload in awaits_response:
                try:
                    if self.driver.ctrl_io.awaiting_response.get(request.msg_id):
                        self.driver.ctrl_io.push_callback_action(request)
                    else:
                        self.driver.ctrl_io.blocking_awaiting_response[
                            request.msg_id
                        ] = getattr(request, payload)
                except Exception as e:
                    self.logger.debug(repr(e))
            elif payload == "notify_complete":
                # Notify the driver to start completion
                # The server assumes that the client is done sending RecordIn messages.
                # Client should have sent any incoming records or otherwise at this point.
                self.driver.ctrl_io.push_driver_callback(
                    self.driver.on_complete_callback
                )

    async def _ctrl_write(self, context) -> None:  # type: ignore
        """
        Send queued ControlOut messages to the client, forever.

        A failed write is logged at debug level and the message is still marked
        done, so ``ctrl_out.join()`` in ``Control`` does not hang.
        """
        while True:
            msg = await self.driver.ctrl_io.ctrl_out.get()
            try:
                await context.write(msg)
            except Exception as e:  # catch ExecuteBatchError if any
                self.logger.debug(repr(e))
            self.driver.ctrl_io.ctrl_out.task_done()

    async def _user_callback_worker(self) -> None:
        """
        Run queued user callbacks one at a time in the default executor.

        Each action is a dict with ``callback_fn`` and an optional
        ``response_msg`` passed as its only argument. A failing callback is
        logged and reported through ``driver.provider.io.error``; the worker
        keeps running.
        """
        loop = asyncio.get_running_loop()
        while True:
            action = await self.driver.ctrl_io.ctrl_user_callback_actions.get()
            try:
                fn = action["callback_fn"]
                # Awaiting the executor future re-raises any exception from fn.
                if action.get("response_msg"):
                    await loop.run_in_executor(None, fn, action["response_msg"])
                else:
                    await loop.run_in_executor(None, fn)
            except Exception as e:
                # Must not raise itself: the action may lack callback_fn, or
                # the callable (e.g. a functools.partial) may lack __name__.
                callback = action.get("callback_fn")
                fn_name = getattr(callback, "__name__", repr(callback))
                report_str = f"Failed while calling {fn_name}\n {repr(e)}"
                self.logger.error(report_str)
                self.driver.provider.io.error(report_str)
            finally:
                self.driver.ctrl_io.ctrl_user_callback_actions.task_done()

    async def _driver_callback_worker(self) -> None:
        """
        Run queued driver callbacks in order in the default executor.

        Before each driver callback it waits for all queued user callbacks to
        finish. ``on_complete`` is requeued until ``_ready_for_on_complete()``
        holds and a record batch (or empty incoming connection) has been seen.
        ``record_teardown`` is notified, and the worker stops, when
        ``on_complete`` finishes, when any driver callback fails, or straight
        away for ``on_complete`` in update-only mode.

        Raises:
            Whatever exception a driver callback raised, after notifying
            ``record_teardown``.
        """
        # NOTE(review): task_done() is only called for requeued actions (inside
        # _requeue_action), so join() on ctrl_driver_actions would never return.
        while True:
            action = await self.driver.ctrl_io.ctrl_driver_actions.get()
            # Make sure we clear any queued user callbacks
            await asyncio.sleep(0)
            await self.driver.ctrl_io.ctrl_user_callback_actions.join()
            is_on_complete = action["driver_fn"] == self.driver.on_complete_callback
            if is_on_complete:
                self.logger.debug("got on_complete, checking if need to requeue")
                # Evaluate once so the log line reflects the actual decision.
                ready = self._ready_for_on_complete()
                batch_seen = self.record_batch_received.is_set()
                self.logger.debug(
                    "ready for on_complete: %s, record_batch_received.is_set(): %s",
                    ready,
                    batch_seen,
                )
                if not ready or not batch_seen:
                    # NOTE(review): if on_complete is the only queued action it is
                    # fetched again at once, so this spins until ready.
                    self.logger.debug("requeue on_complete")
                    try:
                        await self._requeue_action(
                            action, self.driver.ctrl_io.ctrl_driver_actions
                        )
                    except Exception as e:
                        self.logger.error(e)
                    continue
                elif self.driver.provider.environment.update_only:
                    # Update-only runs skip on_complete and go straight to teardown.
                    async with self.record_teardown:
                        self.record_teardown.notify_all()
                    return
            loop = asyncio.get_running_loop()
            if action["args"]:
                fut = loop.run_in_executor(None, action["driver_fn"], *action["args"])
            else:
                # Other plugin methods have no args
                fut = loop.run_in_executor(None, action["driver_fn"])
            # Handle any generated work from calling method
            # The periodic wake-ups keep servicing the event loop while the
            # callback runs. NOTE(review): presumably relied on if the callback
            # enqueues work from its thread without waking the loop; confirm
            # before replacing this with a plain ``await fut``.
            while not fut.done():
                await asyncio.sleep(0.1)
            action["event_cb_complete"].set()
            if fut.exception():
                # Driver functions still use an error handling wrapper
                # So we don't have to log here, just teardown
                async with self.record_teardown:
                    self.record_teardown.notify_all()
                raise fut.exception()  # type: ignore
            if is_on_complete:
                async with self.record_teardown:
                    self.record_teardown.notify_all()
                return

    async def _requeue_action(self, action: dict, queue: asyncio.Queue) -> None:
        """
        Put ``action`` back on ``queue`` after yielding to the event loop.

        ``task_done()`` balances the ``get()`` that originally fetched the
        action, so requeueing leaves the queue's unfinished count unchanged.
        """
        await asyncio.sleep(0)
        await queue.put(action)
        queue.task_done()

    def _ready_for_on_complete(self) -> bool:
        """
        Return True when nothing that must precede ``on_complete`` is pending.

        That means no queued driver actions, no completed record streams still
        waiting to be dispatched, and no outstanding callback responses.
        """
        no_queued_driver_actions = self.driver.ctrl_io.ctrl_driver_actions.empty()
        no_pending_record_batches = self.driver.record_io.completed_streams.empty()
        no_awaited_responses = len(self.driver.ctrl_io.awaiting_response) < 1
        return all(
            (no_queued_driver_actions, no_pending_record_batches, no_awaited_responses)
        )

    async def RecordTransfer(self, request_iterator, context):  # type: ignore # noqa: N802
        """
        Definition for gRPC RecordTransfer.

        Consumes any data sent by the client, then send any pending RecordTransferOut messages.

        Starts reader, writer and batch-dispatch tasks, then waits on
        ``record_teardown``. Once notified, it queues a ``close_outgoing_anchor``
        message for each eligible outgoing anchor and flushes record I/O. It
        then cancels the helper tasks and notifies ``cond_teardown`` so
        ``Control`` can finish. The helper tasks are also cancelled if this
        stream is cancelled or teardown fails.
        """
        self.logger.debug("Record Transfer stream starting.")
        tasks = [
            asyncio.create_task(
                self._record_read(request_iterator), name="_record_read"
            ),
            asyncio.create_task(self._record_write(context), name="_record_write"),
            asyncio.create_task(
                self._record_driver_actions(), name="_record_driver_actions"
            ),
        ]
        try:
            async with self.record_teardown:
                self.logger.debug("Waiting on teardown notification")
                await self.record_teardown.wait()
            await asyncio.sleep(0)
            self.logger.debug("Received teardown notice. Cancelling tasks")
            try:
                for anchor in self.driver.provider.outgoing_anchors.values():
                    if (
                        anchor["num_connections"] > 0 or anchor.get("metadata")
                    ) and not self.driver.provider.environment.update_only:
                        close_anchor_msg = RecordTransferOut()
                        close_anchor_msg.close_outgoing_anchor.name = anchor["name"]
                        self.driver.record_io.pending_writes.put_nowait(
                            {
                                "write_type": "close_outgoing_anchor",
                                "message": close_anchor_msg,
                            }
                        )
                await self.driver.record_io.flush()
            except Exception as e:
                self.logger.debug(repr(e))
        finally:
            # Cancelling an already-finished task is a no-op.
            for task in tasks:
                task.cancel()
        async with self.cond_teardown:
            self.cond_teardown.notify_all()
        self.logger.debug("Exiting RecordTransfer...")

    async def _record_read(self, request_iterator) -> None:  # type: ignore
        """Receive any RecordTransferIn messages from the client."""
        async for req in request_iterator:
            payload = req.WhichOneof("payload")
            if payload == "incoming_records":
                self.driver.record_io.receive_chunk(req)

    async def _record_write(self, context) -> None:  # type: ignore
        """
        Write and send RecordTransferOut.outgoing_records from write queue.

        An ``outgoing_records`` item is expanded into stream messages and
        written. Any other item is treated as a ``close_outgoing_anchor``
        message, written only when the anchor has connections and no
        ``metadata`` entry and update-only mode is off. Failures are logged.
        ``task_done()`` is always called, so anything joining
        ``pending_writes`` (presumably ``record_io.flush``) is not left waiting.
        """
        while True:
            # If user has written to buffer, send records
            to_write = await self.driver.record_io.pending_writes.get()
            try:
                if to_write["write_type"] == "outgoing_records":
                    for msg in self.driver.record_io.get_stream_msgs(to_write):
                        await context.write(msg)
                else:
                    anchor = self.driver.provider.outgoing_anchors[
                        to_write["message"].close_outgoing_anchor.name
                    ]
                    if (
                        anchor["num_connections"] > 0
                        and not anchor.get("metadata", False)
                    ) and not self.driver.provider.environment.update_only:
                        self.logger.debug("Wrote close anchor")
                        await context.write(to_write["message"])
            except Exception as e:
                self.logger.error("Failed during record_write: %s", repr(e))
            finally:
                self.driver.record_io.pending_writes.task_done()

    async def _record_driver_actions(self) -> None:
        """
        Handle events related to receiving or sending batches.

        For each completed incoming stream, it waits until ``ready_for_records``
        is set. It then queues the driver's ``record_batch_received`` callback
        for the batch and anchor, and sets ``record_batch_received``.
        """
        while True:
            batch_item = await self.driver.record_io.completed_streams.get()
            # Lazy %-formatting: the batch is only rendered if debug is enabled.
            self.logger.debug("Got a msg from completed_streams queue: %s", batch_item)
            # wait for plugin_init to finish before calling on_record_batch
            await self.ready_for_records.wait()
            self.driver.ctrl_io.push_driver_callback(
                self.driver.record_batch_received,
                batch_item["record_batch"],
                batch_item["anchor"],
            )
            self.record_batch_received.set()
            self.driver.record_io.completed_streams.task_done()