Metadata-Version: 2.4
Name: dataproc-ml
Version: 0.1.0
Summary: A python library to ease MLOps for Dataproc customers
Author: Google LLC
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/GoogleCloudDataproc/dataproc-ml-python
Project-URL: Documentation, https://dataproc-ml.readthedocs.io/
Project-URL: Issues, https://github.com/GoogleCloudDataproc/dataproc-ml-python/issues
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: google-cloud-aiplatform<2.0.0,>=1.88.0
Requires-Dist: google-cloud-storage<3.0.0,>=2.19.0
Requires-Dist: pandas<3.0.0,>=2.1.4
Requires-Dist: pyarrow<17.0.0,>=16.1.0
Requires-Dist: pyspark<4.0.0,>=3.5.3
Requires-Dist: tenacity<9.0.0,>=8.5.0
Requires-Dist: tensorflow<2.20.0,>=2.17.0
Requires-Dist: torch<3.0.0,>=2.6.0
Provides-Extra: test
Requires-Dist: pytest; extra == "test"
Requires-Dist: torchvision<1.0.0,>=0.21.0; extra == "test"
Requires-Dist: pillow<12.0.0,>=11.3.0; extra == "test"
Provides-Extra: dev
Requires-Dist: pyink; extra == "dev"
Requires-Dist: pylint; extra == "dev"
Requires-Dist: build; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx_rtd_theme; extra == "docs"
Dynamic: license-file

<div align="center">
  <img src="https://cloud.google.com/images/social-icon-google-cloud-1200-630.png" width="120" alt="Google Cloud logo">
  <h1>Dataproc ML</h1>
</div>

[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

> **Public Preview Disclaimer**
>
> Interfaces and functionality are subject to change. This library is not recommended for production-critical applications without thorough testing and an understanding of the potential risks.

`Dataproc ML` is a Python library that simplifies distributed ML inference on Google Cloud Dataproc. It provides high-level handlers to run PyTorch and Vertex AI Gemini models at scale using Apache Spark, without the complexity of manual model distribution and batch processing.

## Installation

You can install the library using pip:

```bash
pip install dataproc-ml
```

## Usage Examples

Here are a couple of examples demonstrating how to use the handlers for distributed inference on a Spark DataFrame.

### Generative AI (Gemini) Model Inference

> **Note:** `GenAiModelHandler` makes API calls to Vertex AI, which incur costs. Review the Vertex AI Generative AI pricing before running it on large DataFrames.

Use Google's Gemini models to perform generative tasks on your data.
This example uses a prompt template to ask for the capital of each country listed in a Spark DataFrame.

```python
from pyspark.sql import SparkSession
from google.cloud.dataproc_ml.inference import GenAiModelHandler

spark = SparkSession.builder.getOrCreate()

# Create a sample DataFrame
data = [("USA",), ("France",), ("Japan",)]
input_df = spark.createDataFrame(data, ["country"])

# The handler fills the {country} placeholder from the 'country' column
result_df = (
    GenAiModelHandler()
    .prompt("What is the capital of {country}?")
    .output_col("capital_city")
    .transform(input_df)
)

result_df.show(truncate=False)
# Example output (model responses can vary between runs):
# +-------+----------------+
# |country|capital_city    |
# +-------+----------------+
# |USA    |Washington, D.C.|
# |France |Paris           |
# |Japan  |Tokyo           |
# +-------+----------------+
```
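The prompt template uses `str.format`-style placeholders. As a rough mental model only, the plain-Python sketch below shows how the columns referenced by a template could be discovered and substituted row by row. It assumes each placeholder is a plain column name. `template_fields` and `render_prompts` are hypothetical helpers and are not how `GenAiModelHandler` is implemented.

```python
from string import Formatter


def template_fields(template: str) -> list[str]:
    """Return the unique placeholder names in a str.format-style template, in order."""
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text that has no placeholder.
    names = [field for _, field, _, _ in Formatter().parse(template) if field]
    return list(dict.fromkeys(names))  # drop duplicates, keep first-seen order


def render_prompts(template: str, rows: list[dict]) -> list[str]:
    """Fill the template once per row, passing only the referenced columns."""
    fields = template_fields(template)
    return [template.format(**{f: row[f] for f in fields}) for row in rows]


template = "What is the capital of {country}?"
rows = [{"country": "USA"}, {"country": "France"}, {"country": "Japan"}]

print(template_fields(template))
for prompt in render_prompts(template, rows):
    print(prompt)
```

With this mental model, a template that references a column missing from the DataFrame would fail with a `KeyError` before any API call is made, which is why checking placeholder names against your column names early is useful.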

### PyTorch Model Inference

Run distributed inference using a pre-trained PyTorch model stored in Google Cloud Storage.
This example builds a small Spark DataFrame `input_df` with a column named `features` holding numeric arrays; in practice, this column could contain image tensors or other numerical data.

```python
from pyspark.sql import SparkSession
from google.cloud.dataproc_ml.inference import PyTorchModelHandler

spark = SparkSession.builder.getOrCreate()

data = [([0.1, 0.2, 0.3],), ([0.4, 0.5, 0.6],), ([0.7, 0.8, 0.9],)]
input_df = spark.createDataFrame(data, ["features"])

# Path to your saved PyTorch model in GCS
model_gcs_path = "gs://your-bucket/path/to/model.pt"

# Apply the model for inference
result_df = (
    PyTorchModelHandler()
    .model_path(model_gcs_path)
    .input_cols("features")
    .transform(input_df)
)

result_df.show(truncate=False)
# Illustrative output; actual predictions depend on your model:
# +---------------+--------------+
# |features       |predictions   |
# +---------------+--------------+
# |[0.1, 0.2, 0.3]|[0.543, 0.457]|
# |[0.4, 0.5, 0.6]|[0.621, 0.379]|
# |[0.7, 0.8, 0.9]|[0.789, 0.211]|
# +---------------+--------------+
```
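The handlers are meant to spare you the manual batch processing that distributed inference usually needs. For intuition only, here is a minimal pure-Python sketch of that batching pattern: rows are grouped into fixed-size batches, the model runs once per batch, and predictions are concatenated in input order. In Spark, a function like this would typically run once per partition (for example, inside `DataFrame.mapInPandas`). `toy_model` is a hypothetical stand-in for a PyTorch model, the batch size is arbitrary, and this is not how `PyTorchModelHandler` is implemented.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, Sequence

Features = Sequence[float]


def batched(rows: Iterable[Features], batch_size: int) -> Iterator[list[Features]]:
    """Yield consecutive lists of at most batch_size rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    # islice pulls up to batch_size rows; an empty list means the input is exhausted.
    while batch := list(islice(it, batch_size)):
        yield batch


def predict_in_batches(
    rows: Iterable[Features],
    model: Callable[[list[Features]], list[list[float]]],
    batch_size: int = 2,
) -> list[list[float]]:
    """Run the model once per batch and concatenate results in input order."""
    predictions: list[list[float]] = []
    for batch in batched(rows, batch_size):
        predictions.extend(model(batch))
    return predictions


def toy_model(batch: list[Features]) -> list[list[float]]:
    """Hypothetical stand-in model: returns [mean, max] for each row."""
    return [[sum(r) / len(r), max(r)] for r in batch]


rows = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
print([len(b) for b in batched(rows, 2)])
print(predict_in_batches(rows, toy_model))
```

Batching amortizes per-call overhead (such as moving data to a GPU) across many rows, while keeping memory bounded by the batch size rather than the partition size.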

## Documentation

For more detailed information on the available handlers and their configurations,
please refer to our official [documentation](https://dataproc-ml.readthedocs.io/).

## Contributing

Contributions are welcome! See CONTRIBUTING.md for details on setting up your
development environment and running linters and tests.

## License

This project is licensed under the Apache 2.0 License. See the LICENSE file for more details.
