加快 Keras 模型预测加载时间

数据挖掘 喀拉斯 预言 api
2022-02-12 18:26:56

我正在尝试使用加载模型预测并关闭模型的 keras 创建一个预测 API。但是在 python 中的初始化时间大约是 3-5 秒,所以每个请求大约需要 5 秒来返回预测,而不管输入行的数量(预测)

有没有办法让模型保持加载,然后流式传输输入数据以获得预测。就像通过套接字或端口预加载的模型一样。

类似于 open office 文档转换器

\program\soffice.exe -accept="socket,host=127.0.0.1,port=8100;urp;" -headless -nofirststartwizard -nologo

Keras预测代码

#!/usr/bin/env python3.6
import sys
import pandas as pd
from keras.models import load_model
model = load_model('model.h5')
X = pd.read_csv(sys.argv[1]).values
prediction = model.predict(X)
pd.DataFrame(prediction).to_json(sys.argv[2])

脚本被称为

python3.6 predict.py input_scaled.csv output_scaled.json

预测时间如下

#row    time
1       4.76 secs
10      4.49 secs
50      5.37 secs
5000    5.46 secs
50000   12.7 secs
2个回答

我能够在没有烧瓶或 django 的情况下像这样工作。只需在 python 中使用默认的 http.server

from http.server import BaseHTTPRequestHandler, HTTPServer
import logging
import sys
import pandas as pd
from keras.models import load_model
from urllib.parse import urlparse
model = load_model('model.h5')

class S(BaseHTTPRequestHandler):
    def _set_response(self):
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()

    def do_GET(self):
        query = urlparse(self.path).query
        params = dict(qc.split("=") for qc in query.split("&"))
        X = pd.read_csv(params["input"]).values
        prediction = model.predict(X)
        pd.DataFrame(prediction).to_json(params["output"])
        self._set_response()
        self.wfile.write("Processed".encode('utf-8'))

def run(server_class=HTTPServer, handler_class=S, port=8080):
    logging.basicConfig(level=logging.INFO)
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    logging.info('Starting httpd...\n')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    logging.info('Stopping httpd...\n')

if __name__ == '__main__':
    from sys import argv

    if len(argv) == 2:
        run(port=int(argv[1]))
    else:
        run()

触发服务器使用

python3.6 predict_server.py 8000

类似 API

http://ip/localhost:8000/?input=predict_scaled.csv&output=prediction.json

我能想到的最简单的方法是创建一个烧瓶应用程序,该应用程序将加载一次模型并具有端点,您可以将数据作为请求发送到已加载的模型。

服务的粗略框架如下所示:

from flask import Flask, request


app = Flask(__name__)


@app.route('/')
def index():
    return ''

@app.route('/predict/', methods=['GET', 'POST'])
def predict():
    X = pd.read_csv(request.get_data()).values
    prediction = model.predict(X)
    return pd.DataFrame(prediction).to_json()


if __name__ == "__main__":
    model = load_model('model.h5')
    app.run()

然后你可以通过另一个脚本向 发出 HTTP 请求localhost:5000/predict,它会返回你的预测,然后你可以保存或做任何你想做的事情。