#!/usr/bin/env python
import json
import os
from io import StringIO
import flask
import pandas as pd
from flask import Flask, Response

# Your imports here


def model_fn(model_dir):
    """Required model loading for Sagemaker framework.

    Template hook: replace the body with code that loads the trained model
    from ``model_dir`` and returns it.

    Args:
        model_dir: Directory that contains the serialized model artifacts.

    Returns:
        The loaded model object, which is later handed to ``predict_fn``.

    Raises:
        NotImplementedError: Until model loading has been implemented.
    """
    # TODO Load a specific model
    # Fail with an explicit message instead of returning an undefined name,
    # which would surface as an opaque NameError when this function is called.
    raise NotImplementedError(
        "model_fn is not implemented: load the model from {!r} and return it".format(model_dir)
    )


def predict_fn(input_data, model):
    """Run ``model`` on ``input_data`` and hand back whatever it predicts.

    Args:
        input_data: Payload forwarded unchanged to ``model.predict``.
        model: Object exposing a ``predict`` method.

    Returns:
        The value returned by ``model.predict(input_data)``.
    """
    # TODO Add any preprocessing or prediction related logic here
    return model.predict(input_data)


# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
# Load the model once at import time; the /invocations handler reuses this object.
model = model_fn(model_dir='/opt/ml/model')


@app.route("/ping", methods=["GET"])
def ping():
    """Health-check endpoint: always answer HTTP 200 with a newline body."""
    health_response = Response(status=200, response="\n")
    return health_response


@app.route("/invocations", methods=["POST"])
def predict():
    """Compound prediction function for the model.

    Parses the request body as header-less CSV, runs ``predict_fn`` on the
    resulting DataFrame with the module-level ``model``, and returns the
    predictions as a single CSV column.

    Returns:
        A ``text/csv`` response with status 200 on success, or a status-400
        response when the content type is not ``text/csv`` or the body is not
        valid UTF-8 encoded CSV.
    """
    # `mimetype` is the Content-Type without parameters and lowercased, so
    # headers such as "text/csv; charset=utf-8" are accepted as CSV too.
    if flask.request.mimetype != 'text/csv':
        return Response(response="Unsupported content type", status=400)

    # Read the raw input data as CSV into a pandas dataframe.
    try:
        csv_text = flask.request.data.decode('utf-8')
        X = pd.read_csv(StringIO(csv_text), header=None)
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError):
        # A bad payload is a client error; report it as such instead of
        # letting the exception surface as an HTTP 500.
        return Response(response="Invalid CSV payload", status=400)

    predictions = predict_fn(X, model)

    # Serialize the predictions back to CSV, one value per line.
    out = StringIO()
    pd.DataFrame({'results': predictions}).to_csv(out, header=False, index=False)
    return Response(response=out.getvalue(), status=200, mimetype='text/csv')


if __name__ == "__main__":
    # Only runs when this file is executed directly as a script;
    # host 0.0.0.0 makes the server listen on all network interfaces.
    app.run(host="0.0.0.0", port=8080)  # Same port as in Dockerfile
