Refactor Python Modules (#29)

This commit is contained in:
Alex 2022-02-20 21:17:36 -06:00 committed by GitHub
parent c894e36855
commit f181dba964
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 67 additions and 680 deletions

View file

@ -1,27 +1,17 @@
from typing import Optional
from pydantic import BaseModel
import numpy as np
from fastapi import FastAPI
import tensorflow as tf
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
IMG_SIZE = 299
PREDICTION_MODEL = InceptionV3(weights='imagenet')
from .object_detection import object_detection
from .image_classifier import image_classifier
from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4
def warm_up():
img_path = f'./app/test.png'
img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
PREDICTION_MODEL.predict(x)
HEIGHT, WIDTH = (640, 960)
# Warm up model
warm_up()
image_classifier.warm_up()
app = FastAPI()
@ -31,21 +21,26 @@ class TagImagePayload(BaseModel):
@app.post("/tagImage")
async def post_root(payload: TagImagePayload):
imagePath = payload.thumbnail_path
image_path = payload.thumbnail_path
if imagePath[0] == '.':
imagePath = imagePath[2:]
if image_path[0] == '.':
image_path = image_path[2:]
img_path = f'./app/{imagePath}'
img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
return image_classifier.classify_image(image_path=image_path)
preds = PREDICTION_MODEL.predict(x)
result = decode_predictions(preds, top=3)[0]
payload = []
for _, value, _ in result:
payload.append(value)
return payload
@app.get("/")
async def test():
object_detection.run_detection()
# image = tf.io.read_file("./app/cars.jpg")
# image = tf.image.decode_image(image)
# image = tf.image.resize(image, (HEIGHT, WIDTH))
# images = tf.expand_dims(image, axis=0) / 255.0
# model = YOLOv4(
# (HEIGHT, WIDTH, 3),
# 80,
# YOLOV4_ANCHORS,
# "darknet",
# )