Source code for ayx_python_sdk.providers.amp_provider.repositories.input_record_packet_repository
# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class that saves/retrieves input record packets."""
import logging
from typing import Dict, List, TYPE_CHECKING, Tuple
from ayx_python_sdk.core.input_connection_base import InputConnectionStatus
from ayx_python_sdk.providers.amp_provider.amp_record_packet import AMPRecordPacket
from ayx_python_sdk.providers.amp_provider.builders.record_packet_builder import (
RecordPacketBuilder,
)
from ayx_python_sdk.providers.amp_provider.repositories.input_connection_repository import (
InputConnectionRepository,
)
from ayx_python_sdk.providers.amp_provider.repositories.input_metadata_repository import (
InputMetadataRepository,
)
from ayx_python_sdk.providers.amp_provider.repositories.singleton import Singleton
if TYPE_CHECKING:
from ayx_python_sdk.core.metadata import Metadata
from ayx_python_sdk.core.record_packet_base import RecordPacketBase
from ayx_python_sdk.providers.amp_provider.resources.generated.record_packet_pb2 import (
RecordPacket as ProtobufRecordPacket,
)
import pandas as pd # noqa: F401
logger = logging.getLogger(__name__)
class UnfinishedRecordPacketException(Exception):
"""Exception to be raised to indicate that a record packet isn't ready to be returned."""
pass
class EmptyRecordPacketRepositoryException(Exception):
"""Exception to be raised after the final record packet has been returned."""
pass
class InputRecordPacketRepository(metaclass=Singleton):
"""Repository that stores input record packets."""
_record_packet_builder = RecordPacketBuilder()
_input_connection_repo = InputConnectionRepository()
def __init__(self) -> None:
"""Initialize the input record packet repository."""
self._record_packet_cache: Dict[
str, Dict[str, Tuple["pd.DataFrame", "pd.DataFrame", int]]
] = {}
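        # Raw incoming packets stored as dataframes, keyed by anchor name and
        # then connection name.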
self._records_list: Dict[str, Dict[str, List["pd.DataFrame"]]] = {}
    def push_record_packet(
self, anchor_name: str, connection_name: str, record_packet: "RecordPacketBase"
) -> None:
"""Save a record packet."""
logger.debug(
"Saving record packet (%r) for anchor %s on connection %s",
record_packet,
anchor_name,
connection_name,
)
self._records_list.setdefault(anchor_name, {})
self._records_list[anchor_name].setdefault(connection_name, [])
self._record_packet_cache.setdefault(anchor_name, {})
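        # Queue the packet's rows as a dataframe; chunking to max_packet_size
        # happens lazily in _reshape_packets.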
self._records_list[anchor_name][connection_name].append(
record_packet.to_dataframe()
)
    def save_grpc_record_packet(
self,
anchor_name: str,
connection_name: str,
grpc_record_packet: "ProtobufRecordPacket",
metadata: "Metadata",
) -> None:
"""Save a record packet from its protobuffer format."""
record_packet, _, _ = self._record_packet_builder.from_protobuf(
grpc_record_packet, metadata
)
self.push_record_packet(anchor_name, connection_name, record_packet)
def _reshape_packets(
self, anchor_name: str, connection_name: str
) -> Tuple["pd.DataFrame", "pd.DataFrame", int]:
"""
Reshape packets based on number of requested rows.
Concatenate record packets from the queue into a single dataframe,
then return that dataframe and the number of record packets to remove from the queue.
Parameters
----------
anchor_name
The name of the input anchor that the metadata is associated with.
connection_name
The name of the input connection that the metadata is associated with.
Returns
-------
Tuple["pd.DataFrame", "pd.DataFrame", int]
A tuple containing:
Dataframe containing the correct number of packets,
the remainder of the original dataframe, and
number of packets to remove from internal queue
"""
import numpy as np # noqa: F811
import pandas as pd # noqa: F811
if anchor_name not in self._records_list:
raise ValueError(f"Anchor {anchor_name} not found in repository.")
if connection_name not in self._records_list[anchor_name]:
raise ValueError(
f"Connection {connection_name} not found in repository for anchor {anchor_name}."
)
if connection_name in self._record_packet_cache[anchor_name]:
return self._record_packet_cache[anchor_name][connection_name]
max_packet_size = self._input_connection_repo.get_connection(
anchor_name, connection_name
).max_packet_size
if len(self._records_list[anchor_name][connection_name]) == 0:
raise EmptyRecordPacketRepositoryException
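        # No packet-size limit: everything queued becomes a single packet and
        # the whole queue is consumed.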
if max_packet_size is None:
self._record_packet_cache[anchor_name][connection_name] = (
pd.concat(self._records_list[anchor_name][connection_name]),
pd.DataFrame(),
len(self._records_list[anchor_name][connection_name]),
)
return self._record_packet_cache[anchor_name][connection_name]
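        # Running row totals across queued packets, used to find how many whole
        # packets fit before the limit is crossed.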
cumulative_lengths: np.ndarray = np.cumsum(
[len(packet) for packet in self._records_list[anchor_name][connection_name]]
)
        if (
            cumulative_lengths[-1] < max_packet_size
            and InputConnectionRepository().get_connection_status(
                anchor_name, connection_name
            )
            != InputConnectionStatus.CLOSED
        ):
            raise UnfinishedRecordPacketException
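        # Candidate packet counts: the first index (past 0) whose running total
        # exceeds max_packet_size, falling back to consuming every queued packet.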
packets = [
idx
for idx, element in enumerate(cumulative_lengths)
if element > max_packet_size and idx > 0
] + [len(cumulative_lengths)]
num_packets_to_merge = packets[0]
df = pd.concat(
self._records_list[anchor_name][connection_name][:num_packets_to_merge]
)
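        # Split the merged dataframe: at most max_packet_size rows form the
        # packet, anything beyond that is carried as overflow.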
right_size_dataframe = df.iloc[:max_packet_size]
overflow_dataframe = df.iloc[max_packet_size:]
self._record_packet_cache[anchor_name][connection_name] = (
right_size_dataframe,
overflow_dataframe,
num_packets_to_merge,
)
return self._record_packet_cache[anchor_name][connection_name]
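    # Worked example of _reshape_packets, assuming illustrative sizes: with
    # queued packets of 40, 40, and 40 rows and a max_packet_size of 100, the
    # cumulative lengths are [40, 80, 120]; the first crossing index past 0 is
    # 2, so two packets merge into an 80-row packet with no overflow and one
    # 40-row packet stays queued. With a single 150-row packet, all of it is
    # merged and the iloc split yields a 100-row packet plus a 50-row overflow
    # that pop_record_packet re-queues at the front.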
    def peek_record_packet(
self, anchor_name: str, connection_name: str
) -> "RecordPacketBase":
"""
Get the next record packet without popping from the queue.
Parameters
----------
anchor_name
The name of the input anchor that the metadata is associated with.
connection_name
The name of the input connection that the metadata is associated with.
Returns
-------
RecordPacketBase
The AMPRecordPacket at the front of the internal queue.
"""
self._record_packet_cache.setdefault(anchor_name, {})
right_size_dataframe, _, _ = self._reshape_packets(anchor_name, connection_name)
return AMPRecordPacket(
InputMetadataRepository().get_metadata(anchor_name, connection_name),
right_size_dataframe,
)
    def pop_record_packet(
self, anchor_name: str, connection_name: str
) -> "RecordPacketBase":
"""
Retrieve record packet if there are enough records to meet the max packet size criteria.
Parameters
----------
anchor_name
The name of the input anchor that the metadata is associated with.
connection_name
The name of the input connection that the metadata is associated with.
Returns
-------
RecordPacketBase
The AMPRecordPacket that was popped off the internal queue.
"""
(
right_size_dataframe,
remainder_packet,
packets_to_remove,
) = self._reshape_packets(anchor_name, connection_name)
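        # Drop the queued packets that were merged into the returned dataframe.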
self._records_list[anchor_name][connection_name] = self._records_list[
anchor_name
][connection_name][packets_to_remove:]
if connection_name in self._record_packet_cache[anchor_name]:
del self._record_packet_cache[anchor_name][connection_name]
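        # Rows beyond max_packet_size go back to the front of the queue so the
        # next peek/pop sees them first.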
if not remainder_packet.empty:
self._records_list[anchor_name][connection_name].insert(0, remainder_packet)
return AMPRecordPacket(
InputMetadataRepository().get_metadata(anchor_name, connection_name),
right_size_dataframe,
)
    def clear_repository(self) -> None:
"""Delete all data in the repository."""
logger.debug("Clearing InputRecordPacketRepository")
self._records_list = {}
self._record_packet_cache = {}
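if __name__ == "__main__":
    # A minimal, self-contained sketch of the reshaping arithmetic in
    # _reshape_packets, using plain pandas/numpy so it needs no repository,
    # connection, or metadata state. The queued packet sizes (40, 40, 40) and
    # the max_packet_size of 100 are illustrative values, not SDK defaults.
    import numpy as np
    import pandas as pd

    queued = [pd.DataFrame({"value": range(n)}) for n in (40, 40, 40)]
    max_packet_size = 100

    cumulative = np.cumsum([len(p) for p in queued])  # array([ 40,  80, 120])
    # First index (past 0) whose running total exceeds the limit, falling back
    # to consuming every queued packet.
    candidates = [
        idx
        for idx, total in enumerate(cumulative)
        if total > max_packet_size and idx > 0
    ] + [len(cumulative)]
    num_to_merge = candidates[0]  # 2: merge the first two 40-row packets

    merged = pd.concat(queued[:num_to_merge])
    packet = merged.iloc[:max_packet_size]  # 80 rows returned to the tool
    overflow = merged.iloc[max_packet_size:]  # empty; re-queued when non-empty
    print(len(packet), len(overflow), num_to_merge)  # 80 0 2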