# Copyright 2021 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""``PlotlyDataSet`` saves plotly objects to a JSON file and loads JSON plotly figures
into plotly.graph_objects.Figure objects.
"""
from typing import Any, Dict
import pandas as pd
import plotly.express as px
import plotly.io as pio
from plotly import graph_objects
from kedro.extras.datasets.pandas import JSONDataSet
from kedro.io.core import Version, get_filepath_str
class PlotlyDataSet(JSONDataSet):
    """``PlotlyDataSet`` saves a pandas DataFrame to a plotly JSON file.

    The plotly JSON file can be saved to any underlying filesystem
    supported by fsspec (e.g. local, S3, GCS).

    Warning: This DataSet is not symmetric and doesn't load back
    into pandas DataFrames, but into plotly.graph_objects.Figure.

    Example configuration for a PlotlyDataSet in the catalog:
    ::

        >>> bar_plot:
        >>>     type: plotly.PlotlyDataSet
        >>>     filepath: data/08_reporting/bar_plot.json
        >>>     plotly_args:
        >>>         type: bar
        >>>         fig:
        >>>             x: features
        >>>             y: importance
        >>>             orientation: 'h'
        >>>         layout:
        >>>             xaxis_title: 'x'
        >>>             yaxis_title: 'y'
        >>>             title: 'Test'

    """

    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        filepath: str,
        plotly_args: Dict[str, Any],
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ) -> None:
        """Creates a new instance of ``PlotlyDataSet`` pointing to a
        plotly.graph_objects.Figure saved as a concrete JSON file on a specific
        filesystem.

        Args:
            filepath: Filepath in POSIX format to a JSON file prefixed with a
                protocol like `s3://`. If prefix is not provided `file` protocol
                (local filesystem) will be used. The prefix should be any
                protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            plotly_args: Plotly configuration for generating a plotly graph
                object Figure representing the plotted data. The keys read from
                it are ``type``, ``fig``, ``theme`` and ``layout``.
            load_args: Plotly options for loading JSON files.
                Here you can find all available arguments:
                https://plotly.com/python-api-reference/generated/plotly.io.from_json.html#plotly.io.from_json
                All defaults are preserved.
            save_args: Plotly options for saving JSON files.
                Here you can find all available arguments:
                https://plotly.com/python-api-reference/generated/plotly.io.write_json.html
                All defaults are preserved; this class defines no default save
                arguments of its own (``DEFAULT_SAVE_ARGS`` is empty).
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: Credentials required to get access to the underlying
                filesystem. E.g. for ``GCSFileSystem`` it should look like
                `{'token': None}`.
            fs_args: Extra arguments to pass into underlying filesystem class
                constructor (e.g. `{"project": "my-project"}` for
                ``GCSFileSystem``), as well as to pass to the filesystem's
                `open` method through nested keys `open_args_load` and
                `open_args_save`.
                Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                All defaults are preserved, except `mode`, which is set to `r`
                when loading and to `w` when saving.
        """
        # Everything except ``plotly_args`` is handled by the parent dataset.
        super().__init__(filepath, load_args, save_args, version, credentials, fs_args)
        self._plotly_args = plotly_args

    def _describe(self) -> Dict[str, Any]:
        """Returns the parent description extended with the plotly configuration."""
        return {**super()._describe(), "plotly_args": self._plotly_args}

    def _save(self, data: pd.DataFrame) -> None:
        """Renders ``data`` as a plotly figure and writes it out as JSON.

        Args:
            data: pandas DataFrame to plot according to ``plotly_args``.
        """
        plot_data = _plotly_express_wrapper(data, self._plotly_args)

        full_key_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(full_key_path, **self._fs_open_args_save) as fs_file:
            plot_data.write_json(fs_file, **self._save_args)

        self._invalidate_cache()

    def _load(self) -> graph_objects.Figure:
        """Reads the stored JSON file and rebuilds the plotly figure.

        Returns:
            The ``plotly.graph_objects.Figure`` produced by ``plotly.io.from_json``.
        """
        load_path = get_filepath_str(self._get_load_path(), self._protocol)

        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
            # read_json doesn't work correctly with file handler, so we have to
            # read the file ourselves and pass the text to the low-level
            # from_json instead.
            payload = fs_file.read()

        # A handle opened in binary mode yields bytes that must be decoded; a
        # handle opened in text mode already yields ``str``, and calling
        # ``str(text, "utf-8")`` on it would raise ``TypeError``.
        if not isinstance(payload, str):
            payload = str(payload, "utf-8")

        return pio.from_json(payload, **self._load_args)
def _plotly_express_wrapper(
    data: pd.DataFrame, plotly_config: Dict[str, Any]
) -> graph_objects.Figure:
    """Generates plotly graph object Figure based on the type of plotting
    and config provided in the catalog.

    Args:
        data: pandas dataframe to generate plotly Figure for
        plotly_config: plotly configurations specified in the catalog to be used.
            Recognised keys:

            * ``type``: name of the ``plotly.express`` function to call
              (e.g. ``bar``).
            * ``fig``: optional keyword arguments forwarded to that function.
            * ``theme``: optional layout template, ``"plotly"`` if omitted.
            * ``layout``: optional layout properties applied after the theme.

    Returns:
        A plotly graph_object figure representing the plotted data
    """
    plot_type = plotly_config.get("type")
    # ``fig`` and ``layout`` are optional; a missing key or an explicit null in
    # the catalog must not break the ``**`` expansion below.
    fig_params = plotly_config.get("fig") or {}
    layout_params = plotly_config.get("layout") or {}
    theme = plotly_config.get("theme", "plotly")

    # Look up the plotly express function by name and build the figure.
    fig = getattr(px, plot_type)(data, **fig_params)  # type: ignore
    # Apply the theme first so explicit layout settings take precedence.
    fig.update_layout(template=theme)
    fig.update_layout(layout_params)
    return fig