-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
71 lines (61 loc) · 2.77 KB
/
app.py
File metadata and controls
71 lines (61 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from flask import Flask, request, jsonify # Framework principal pour créer l'API
from flasgger import Swagger, swag_from # Pour la documentation Swagger de l'API
import numpy as np # Utilisé pour manipuler les tableaux de données d'entrée
from utils.model_utils import load_model # Pour charger le modèle pré-entraîné
from utils.validation import validate_input # Pour valider les données d'entrée
from utils.swagger_config import swagger_template # Importer le template Swagger
import json # Pour encoder et décoder des données JSON
# Initialiser Flask
app = Flask(__name__)
# Configurer Swagger pour la documentation API avec un template
swagger = Swagger(app, template=swagger_template)
# Charger le modèle au démarrage de l'application
model = load_model()
@app.before_request
def configure_swagger():
"""
Configure dynamiquement le schéma (http/https) et l'hôte pour Swagger.
Cette fonction est appelée avant chaque requête pour garantir que l'hôte et le schéma sont définis correctement.
"""
scheme = request.scheme # HTTP ou HTTPS
host = request.host # Hôte (localhost:5000 ou hôte de production)
swagger_template['schemes'] = [scheme]
swagger_template['host'] = host
app.config['SWAGGER'] = swagger_template # Mettre à jour la configuration Swagger
@app.route('/predict', methods=['POST'])
@swag_from('swagger/predict.yml') # Documentation Swagger pour l'endpoint /predict
def predict():
"""
Endpoint pour faire une prédiction avec le modèle Iris.
"""
data = request.get_json()
# Validation des données d'entrée
is_valid, error_message = validate_input(data)
if not is_valid:
return jsonify({"error": error_message}), 400
features = np.array([data['features']])
try:
# Faire une prédiction avec le modèle chargé
prediction = model.predict(features)
return jsonify({'prediction': int(prediction[0])}), 200
except ValueError as e:
# Erreur liée aux données (par exemple, mauvaise forme d'entrée)
return jsonify({"error": f"Erreur dans les données d'entrée: {str(e)}"}), 400
except Exception as e:
# Erreur interne du serveur
app.logger.error(f"Erreur interne: {str(e)}")
return jsonify({"error": "Erreur interne du serveur"}), 500
@app.route('/')
@swag_from('swagger/index.yml') # Documentation Swagger pour la page d'accueil
def index():
"""
Page d'accueil de l'API Iris.
"""
response = {"message": "Bienvenue sur l'API de prédiction Iris"}
return app.response_class(
response=json.dumps(response, ensure_ascii=False),
mimetype='application/json'
), 200
if __name__ == '__main__':
# Configurations pour l'exécution de l'application
app.run(host='0.0.0.0', port=5000)