kedro.extras.datasets.tensorflow.TensorFlowModelDataset
class kedro.extras.datasets.tensorflow.TensorFlowModelDataset(filepath, load_args=None, save_args=None, version=None, credentials=None, fs_args=None)

Bases: kedro.io.core.AbstractVersionedDataSet

TensorFlowModelDataset loads and saves TensorFlow models. The underlying functionality is supported by, and passes input arguments through to, the TensorFlow 2.x load_model and save_model methods.

Example:
    from kedro.extras.datasets.tensorflow import TensorFlowModelDataset
    import tensorflow as tf
    import numpy as np

    data_set = TensorFlowModelDataset("saved_model_path")
    model = tf.keras.Model()  # placeholder: build and train a real model here
    predictions = model.predict([...])

    # Round-trip the model through the data set
    data_set.save(model)
    loaded_model = data_set.load()

    # The reloaded model should reproduce the original predictions
    new_predictions = loaded_model.predict([...])
    np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
Attributes

- TensorFlowModelDataset.DEFAULT_LOAD_ARGS
- TensorFlowModelDataset.DEFAULT_SAVE_ARGS

Methods

- TensorFlowModelDataset.__init__(filepath[, …]) – Creates a new instance of TensorFlowModelDataset.
- TensorFlowModelDataset.exists() – Checks whether a data set’s output already exists by calling the provided _exists() method.
- TensorFlowModelDataset.from_config(name, config) – Create a data set instance using the configuration provided.
- TensorFlowModelDataset.load() – Loads data by delegation to the provided load method.
- TensorFlowModelDataset.release() – Release any cached data.
- TensorFlowModelDataset.resolve_load_version() – Compute the version the dataset should be loaded with.
- TensorFlowModelDataset.resolve_save_version() – Compute the version the dataset should be saved with.
- TensorFlowModelDataset.save(data) – Saves data by delegation to the provided save method.
DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS = {'save_format': 'tf'}
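The two defaults above combine with the load_args and save_args passed to __init__. The sketch below shows that merge. It assumes that user-supplied keys override the defaults one key at a time and that the class-level dictionaries are never mutated. `merge_args` is a hypothetical helper and is not part of Kedro's API.

```python
import copy
from typing import Any, Dict, Optional

# Values copied from the attribute entries above.
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"save_format": "tf"}


def merge_args(defaults: Dict[str, Any], overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: copy the defaults, then let user-supplied keys win."""
    merged = copy.deepcopy(defaults)  # never mutate the class-level dict
    if overrides is not None:
        merged.update(overrides)
    return merged


# include_optimizer is a tf.keras save_model argument. Here it is simply carried through.
save_args = merge_args(DEFAULT_SAVE_ARGS, {"save_format": "h5", "include_optimizer": False})
load_args = merge_args(DEFAULT_LOAD_ARGS, None)
```

In this sketch, passing save_format="h5" replaces the "tf" default and leaves the class attribute untouched.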
__init__(filepath, load_args=None, save_args=None, version=None, credentials=None, fs_args=None)

Creates a new instance of TensorFlowModelDataset.

Parameters:
- filepath (str) – Filepath in POSIX format to a TensorFlow model directory, prefixed with a protocol like s3://. If no prefix is provided, the file protocol (local filesystem) is used. The prefix can be any protocol supported by fsspec. Note: http(s) doesn’t support versioning.
- load_args (Optional[Dict[str, Any]]) – TensorFlow options for loading models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model. All defaults are preserved.
- save_args (Optional[Dict[str, Any]]) – TensorFlow options for saving models. All available arguments are listed at https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model. All defaults are preserved, except for “save_format”, which is set to “tf”.
- version (Optional[Version]) – If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, the save version will be autogenerated.
- credentials (Optional[Dict[str, Any]]) – Credentials required to access the underlying filesystem. For example, for GCSFileSystem it should look like {'token': None}.
- fs_args (Optional[Dict[str, Any]]) – Extra arguments to pass to the underlying filesystem class constructor (e.g. {"project": "my-project"} for GCSFileSystem).

Return type: None
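The filepath, credentials and fs_args descriptions above imply four rules:

- A protocol prefix such as s3:// selects the filesystem.
- A bare path falls back to the local file protocol.
- http(s) paths cannot be versioned.
- credentials and fs_args end up as keyword arguments to the filesystem constructor.

The sketch below simplifies those rules. `split_protocol` and `filesystem_kwargs` are hypothetical names. ValueError stands in for DataSetError. Real path parsing has to handle more cases, such as Windows drive letters.

```python
from typing import Any, Dict, Optional, Tuple

HTTP_PROTOCOLS = ("http", "https")


def split_protocol(filepath: str, versioned: bool = False) -> Tuple[str, str]:
    """Hypothetical helper: return (protocol, path) for a POSIX-style filepath."""
    if "://" not in filepath:
        # No prefix: treat the path as local.
        return "file", filepath
    protocol, path = filepath.split("://", 1)
    if protocol in HTTP_PROTOCOLS:
        if versioned:
            raise ValueError("http(s) paths do not support versioning")
        return protocol, filepath  # keep the full URL for HTTP access
    return protocol, path


def filesystem_kwargs(credentials: Optional[Dict[str, Any]],
                      fs_args: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: combine both dicts into constructor keyword arguments."""
    kwargs: Dict[str, Any] = dict(credentials or {})
    clash = kwargs.keys() & (fs_args or {}).keys()
    if clash:
        # Reject ambiguous keys instead of silently picking one of the values.
        raise ValueError(f"keys given in both credentials and fs_args: {sorted(clash)}")
    kwargs.update(fs_args or {})
    return kwargs
```

With the GCSFileSystem examples from the parameter list, filesystem_kwargs({'token': None}, {"project": "my-project"}) yields a single dictionary that carries both keys.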
exists()

Checks whether a data set’s output already exists by calling the provided _exists() method.

Return type: bool
Returns: Flag indicating whether the output already exists.
Raises: DataSetError – When the underlying exists method raises an error.
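exists(), load() and save() share one shape. The public method delegates to a private hook and turns any failure into DataSetError. The sketch below applies that delegation pattern to a toy in-memory data set. `ToyDataSet` and its dict-backed store are hypothetical. The hook names _load and _save are assumed by analogy with _exists(). The local `DataSetError` class stands in for kedro.io.core.DataSetError.

```python
from typing import Any, Dict


class DataSetError(Exception):
    """Local stand-in for kedro.io.core.DataSetError."""


class ToyDataSet:
    """Hypothetical data set: public methods wrap private hooks and translate errors."""

    def __init__(self, key: str, store: Dict[str, Any]) -> None:
        self._key = key
        self._store = store

    # Public API: delegate to the hook and convert any failure into DataSetError.
    def exists(self) -> bool:
        try:
            return self._exists()
        except Exception as exc:
            raise DataSetError(f"Failed during exists check for {self._key}: {exc}") from exc

    def load(self) -> Any:
        try:
            return self._load()
        except Exception as exc:
            raise DataSetError(f"Failed while loading {self._key}: {exc}") from exc

    def save(self, data: Any) -> None:
        try:
            self._save(data)
        except Exception as exc:
            raise DataSetError(f"Failed while saving {self._key}: {exc}") from exc

    # Private hooks: the only part a concrete data set has to implement.
    def _exists(self) -> bool:
        return self._key in self._store

    def _load(self) -> Any:
        return self._store[self._key]  # raises KeyError if nothing was saved yet

    def _save(self, data: Any) -> None:
        self._store[self._key] = data
```

Chaining with `from exc` keeps the original error available as `__cause__` for debugging.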
classmethod from_config(name, config, load_version=None, save_version=None)

Create a data set instance using the configuration provided.

Parameters:
- name (str) – Data set name.
- config (Dict[str, Any]) – Data set config dictionary.
- load_version (Optional[str]) – Version string to be used for the load operation if the data set is versioned. Has no effect on the data set if versioning was not enabled.
- save_version (Optional[str]) – Version string to be used for the save operation if the data set is versioned. Has no effect on the data set if versioning was not enabled.

Return type: AbstractDataSet
Returns: An instance of an AbstractDataSet subclass.
Raises: DataSetError – When the function fails to create the data set from its config.
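from_config builds a data set from a name and a configuration dictionary. The sketch below makes three assumptions:

- The dictionary names the class under a "type" key.
- All other keys become constructor arguments.
- A "versioned" flag decides whether load_version and save_version are used.

`build_from_config`, `REGISTRY`, `FakeModelDataset` and the local `Version` tuple are hypothetical stand-ins. ValueError replaces DataSetError so the sketch stays standalone.

```python
from typing import Any, Dict, NamedTuple, Optional, Type


class Version(NamedTuple):
    """Stand-in for kedro.io.core.Version: load and save version strings."""
    load: Optional[str]
    save: Optional[str]


class FakeModelDataset:
    """Hypothetical data set that only records its constructor arguments."""

    def __init__(self, filepath: str, save_args: Optional[Dict[str, Any]] = None,
                 version: Optional[Version] = None) -> None:
        self.filepath = filepath
        self.save_args = save_args or {}
        self.version = version


# Hypothetical lookup table from "type" names to classes.
REGISTRY: Dict[str, Type[Any]] = {"FakeModelDataset": FakeModelDataset}


def build_from_config(name: str, config: Dict[str, Any],
                      load_version: Optional[str] = None,
                      save_version: Optional[str] = None) -> Any:
    """Pick the class named by "type" and pass the remaining keys as kwargs."""
    config = dict(config)  # work on a copy so the caller's dict is untouched
    try:
        cls = REGISTRY[config.pop("type")]
    except KeyError as exc:
        raise ValueError(f"Data set '{name}': missing or unknown 'type'") from exc
    # The version strings only take effect when the entry asks for versioning.
    if config.pop("versioned", False):
        config["version"] = Version(load_version, save_version)
    return cls(**config)
```

As the parameter descriptions say, load_version is ignored when versioning is not enabled. In this sketch that is the case when "versioned" is absent.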
load()

Loads data by delegation to the provided load method.

Return type: Any
Returns: Data returned by the provided load method.
Raises: DataSetError – When the underlying load method raises an error.
release()

Release any cached data.

Raises: DataSetError – When the underlying release method raises an error.
Return type: None
resolve_load_version()

Compute the version the dataset should be loaded with.

Return type: Optional[str]
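The version parameter states that when its load attribute is None, the latest version is loaded. The sketch below shows one way to compute that. It assumes versions are stored as <filepath>/<version>/<basename> and that version strings are timestamps that sort chronologically. It also treats the resolved version as the "cached data" that release() clears. `ToyVersionResolver` is hypothetical and reads from a set of paths instead of a real filesystem.

```python
from pathlib import PurePosixPath
from typing import Iterable, Optional


class ToyVersionResolver:
    """Hypothetical resolver for a versioned path; not Kedro's implementation."""

    def __init__(self, filepath: str, existing: Iterable[str],
                 requested: Optional[str] = None) -> None:
        self._filepath = PurePosixPath(filepath)
        self._existing = set(existing)    # stand-in for globbing the real filesystem
        self._requested = requested       # plays the role of Version.load
        self._cached: Optional[str] = None

    def versioned_path(self, version: str) -> PurePosixPath:
        # Assumed layout: <filepath>/<version>/<last component of filepath>
        return self._filepath / version / self._filepath.name

    def resolve_load_version(self) -> Optional[str]:
        if self._requested is not None:
            return self._requested        # an explicit load version always wins
        if self._cached is None:
            candidates = [
                PurePosixPath(p).parent.name
                for p in self._existing
                if PurePosixPath(p) == self.versioned_path(PurePosixPath(p).parent.name)
            ]
            # Sortable timestamp strings: the maximum is the most recent version.
            self._cached = max(candidates) if candidates else None
        return self._cached

    def release(self) -> None:
        # Drop the cached version so the next call re-scans the existing paths.
        self._cached = None
```

Caching means a version saved after the first lookup is only picked up after release().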
resolve_save_version()

Compute the version the dataset should be saved with.

Return type: Optional[str]
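When the save attribute of version is None, the save version is autogenerated. The sketch below generates one. It assumes a UTC timestamp of the form 2021-03-01T12.30.45.123Z. Dots instead of colons keep the string path-safe, and the fixed-width layout makes string order match time order. The format and the helper names are assumptions and are not taken from this page. The sketch only covers the versioned case, so it always returns a string.

```python
from datetime import datetime, timezone
from typing import Optional

# Assumed version format: sortable UTC timestamp with dots instead of colons.
VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"


def generate_timestamp(now: Optional[datetime] = None) -> str:
    """Hypothetical helper: format a UTC time, trimmed to milliseconds."""
    if now is None:
        now = datetime.now(tz=timezone.utc)
    stamp = now.strftime(VERSION_FORMAT)  # e.g. 2021-03-01T12.30.45.123456Z
    return stamp[:-4] + stamp[-1:]         # drop the last three digits, keep "Z"


def resolve_save_version(requested: Optional[str], now: Optional[datetime] = None) -> str:
    """An explicit save version wins; otherwise one is autogenerated."""
    return requested if requested is not None else generate_timestamp(now)
```

Passing a fixed datetime makes the output deterministic, which is convenient in tests.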
save(data)

Saves data by delegation to the provided save method.

Parameters:
- data (Any) – The value to be saved by the provided save method.

Raises: DataSetError – When the underlying save method raises an error.
Return type: None