from flask import Flask, request, jsonify
from flask_cors import CORS
import joblib
import logging
import os
from datetime import datetime
import json
import csv
import pandas as pd
from sklearn.neighbors import NearestNeighbors

app = Flask(__name__)
CORS(app)  # Permet les requêtes cross-origin depuis PrestaShop

# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Chemin du modèle
MODEL_PATH = os.path.join(os.path.dirname(__file__), 'model.pkl')
VALID_PRODUCTS_PATH = os.path.join(os.path.dirname(__file__), 'valid_products.json')
INTERACTIONS_PATH = os.path.join(os.path.dirname(__file__), 'interactions.csv')

# Chargement du modèle et pivot table
model = None
pivot = None
model_error = None

# Ensemble des IDs de produits valides (synchronisés depuis PrestaShop)
valid_product_ids = set()

try:
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Fichier modèle non trouvé: {MODEL_PATH}")
    
    logger.info(f"Tentative de chargement du modèle depuis: {MODEL_PATH}")
    model, pivot = joblib.load(MODEL_PATH)
    logger.info("Modèle chargé avec succès")
    logger.info(f"Taille du pivot: {pivot.shape if pivot is not None else 'None'}")
except Exception as e:
    model_error = str(e)
    logger.error(f"Erreur lors du chargement du modèle: {e}")
    logger.error(f"Type d'erreur: {type(e).__name__}")
    import traceback
    logger.error(traceback.format_exc())
    model, pivot = None, None

# Charger les IDs valides depuis le fichier s'il existe
def load_valid_products():
    global valid_product_ids
    if os.path.exists(VALID_PRODUCTS_PATH):
        try:
            with open(VALID_PRODUCTS_PATH, 'r') as f:
                data = json.load(f)
                valid_product_ids = set(int(pid) for pid in data.get('product_ids', []))
                logger.info(f"Chargé {len(valid_product_ids)} IDs de produits valides depuis {VALID_PRODUCTS_PATH}")
        except Exception as e:
            logger.error(f"Erreur lors du chargement des produits valides: {e}")
            valid_product_ids = set()
    else:
        logger.warning(f"Fichier {VALID_PRODUCTS_PATH} non trouvé. Aucun ID valide chargé.")

load_valid_products()


# NOUVEAU CODE
# Fonctions utilitaires pour le tracking du temps et la mise à jour du modèle

def time_to_rating(seconds: int) -> int:
    """
    Convertit un temps passé (en secondes) en rating implicite.
    - < 5s : 1
    - 5-30s : 3
    - > 30s : 5 (capé)
    """
    if seconds < 5:
        return 1
    if seconds < 30:
        return 3
    return 5


def append_interaction(user_id: int, product_id: int, time_spent_seconds: int, rating: int):
    """
    Ajoute une interaction dans un fichier CSV simple.
    Colonnes : user_id, product_id, rating, time_spent_seconds, timestamp
    """
    write_header = not os.path.exists(INTERACTIONS_PATH)
    try:
        with open(INTERACTIONS_PATH, mode='a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(['user_id', 'product_id', 'rating', 'time_spent_seconds', 'timestamp'])
            writer.writerow([user_id, product_id, rating, time_spent_seconds, datetime.now().isoformat()])
        logger.info(f"Interaction ajoutée: user={user_id}, product={product_id}, rating={rating}, time={time_spent_seconds}s")
    except Exception as e:
        logger.error(f"Erreur lors de l'écriture de l'interaction: {e}")


def rebuild_model_from_interactions():
    """
    Recharge le fichier d'interactions et reconstruit le pivot et le modèle KNN.
    Met à jour les variables globales `pivot` et `model`, et sauvegarde `model.pkl`.
    """
    global pivot, model, model_error

    if not os.path.exists(INTERACTIONS_PATH):
        logger.warning(f"Aucune interaction trouvée ({INTERACTIONS_PATH} inexistant). Le modèle actuel est conservé.")
        return

    try:
        df = pd.read_csv(INTERACTIONS_PATH)
        if df.empty:
            logger.warning("Fichier d'interactions vide, impossible de reconstruire le modèle.")
            return

        # Pivot user x product avec rating moyen
        pivot_new = df.pivot_table(
            index='user_id',
            columns='product_id',
            values='rating',
            aggfunc='mean'
        ).fillna(0)

        if pivot_new.empty:
            logger.warning("Pivot reconstruit vide, conservation de l'ancien modèle.")
            return

        model_new = NearestNeighbors(metric='cosine', algorithm='brute')
        model_new.fit(pivot_new)

        pivot = pivot_new
        model = model_new
        model_error = None

        # Sauvegarde du modèle pour persistance
        joblib.dump((model, pivot), MODEL_PATH)
        logger.info(f"Modèle reconstruit et sauvegardé depuis les interactions. Pivot shape: {pivot.shape}")
    except Exception as e:
        logger.error(f"Erreur lors de la reconstruction du modèle depuis les interactions: {e}")


@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'service': 'Recommendation API',
        'status': 'running',
        'version': '1.0',
        'endpoints': {
            '/recommend': 'POST - Obtenir des recommandations',
            '/health': 'GET - Vérifier le statut',
            '/sync-products': 'POST - Synchroniser les IDs de produits valides'
        }
    })


@app.route('/health', methods=['GET'])
def health():
    """Endpoint de santé pour monitoring"""
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None,
        'model_path': MODEL_PATH,
        'model_exists': os.path.exists(MODEL_PATH),
        'model_error': model_error,
        'timestamp': datetime.now().isoformat()
    })


@app.route('/recommend', methods=['POST'])
def recommend():
    """
    Recommande des produits pour un utilisateur
    Body: {"user_id": 1, "n_recommendations": 5}
    """
    if model is None or pivot is None:
        error_msg = {
            'error': 'Modèle non disponible',
            'model_path': MODEL_PATH,
            'model_exists': os.path.exists(MODEL_PATH),
            'model_error': model_error
        }
        logger.error(f"Tentative d'utilisation du modèle alors qu'il n'est pas chargé: {error_msg}")
        return jsonify(error_msg), 500

    try:
        data = request.json
        user_id = int(data.get('user_id'))
        n_recommendations = int(data.get('n_recommendations', 5))
        
        # Avertir si aucun produit n'est synchronisé
        if not valid_product_ids:
            logger.warning("ATTENTION: Aucun produit valide synchronisé. Les recommandations peuvent contenir des IDs invalides.")
        
        # Fonction pour filtrer les IDs valides
        def filter_valid_ids(product_ids):
            if not valid_product_ids:
                # Si aucun ID valide n'est défini, retourner tous les IDs (comportement par défaut)
                logger.warning("Aucun ID de produit valide défini. Retour de tous les IDs.")
                return product_ids
            # Filtrer pour ne garder que les IDs valides
            valid = [pid for pid in product_ids if int(pid) in valid_product_ids]
            return valid
        
        # Fonction pour obtenir des produits populaires valides
        def get_popular_valid_products(n):
            if not valid_product_ids:
                # Si aucun ID valide, utiliser tous les produits du modèle
                popular = pivot.sum(axis=0).nlargest(n).index.tolist()
                return [int(pid) for pid in popular]
            # Calculer la popularité uniquement pour les produits valides
            valid_cols = [col for col in pivot.columns if int(col) in valid_product_ids]
            if not valid_cols:
                logger.warning("Aucun produit valide trouvé dans le modèle. Utilisation des IDs valides directement.")
                # Si aucun produit du modèle ne correspond, retourner les premiers IDs valides disponibles
                return list(valid_product_ids)[:n]
            popular = pivot[valid_cols].sum(axis=0).nlargest(n).index.tolist()
            return [int(pid) for pid in popular]
        
        # Fonction pour obtenir des IDs valides directement (fallback)
        def get_fallback_valid_products(n, exclude_ids=None):
            """Retourne des IDs valides directement depuis la liste synchronisée"""
            if not valid_product_ids:
                return []
            exclude_set = set(exclude_ids) if exclude_ids else set()
            available = [pid for pid in valid_product_ids if pid not in exclude_set]
            return available[:n]
        
        # Vérifier si l'utilisateur existe
        if user_id not in pivot.index:
            logger.warning(f"Utilisateur {user_id} inconnu")
            # Retourner les produits les plus populaires parmi les valides
            popular_products = get_popular_valid_products(n_recommendations)
            # Si vide, utiliser les IDs valides directement
            if len(popular_products) == 0:
                logger.warning(f"Aucun produit populaire trouvé. Utilisation des IDs valides disponibles.")
                popular_products = get_fallback_valid_products(n_recommendations)
            return jsonify({
                'user_id': user_id,
                'recommendations': popular_products,
                'type': 'popular',
                'message': 'Nouvel utilisateur - recommandations populaires',
                'count': len(popular_products)
            })

        # Trouver des utilisateurs similaires
        n_neighbors = min(6, len(pivot))  # Max 6 voisins
        distances, indices = model.kneighbors(
            [pivot.loc[user_id].values], 
            n_neighbors=n_neighbors
        )
        
        # Exclure l'utilisateur lui-même (premier résultat)
        similar_users = pivot.index[indices[0][1:]].tolist()
        
        # Produits déjà vus par l'utilisateur
        user_products = set(pivot.columns[pivot.loc[user_id] > 0])
        
        # Collecter les produits des utilisateurs similaires avec scores
        product_scores = {}
        for idx, u in enumerate(similar_users):
            user_prods = pivot.loc[u]
            for prod_id, rating in user_prods[user_prods > 0].items():
                if prod_id not in user_products:
                    # Vérifier que le produit est valide (si des IDs valides sont définis)
                    if valid_product_ids and int(prod_id) not in valid_product_ids:
                        continue
                    # Score basé sur la similarité et le rating
                    similarity = 1 - distances[0][idx + 1]
                    score = similarity * rating
                    product_scores[prod_id] = product_scores.get(prod_id, 0) + score
        
        # Trier par score décroissant
        recommended = sorted(
            product_scores.items(), 
            key=lambda x: x[1], 
            reverse=True
        )[:n_recommendations]
        
        recommended_ids = [int(prod_id) for prod_id, _ in recommended]
        
        # Si aucune recommandation personnalisée n'a été trouvée, utiliser les produits populaires
        if len(recommended_ids) == 0:
            logger.info(f"Aucune recommandation personnalisée trouvée pour user {user_id}. Utilisation de produits populaires.")
            recommended_ids = get_popular_valid_products(n_recommendations)
            # Si toujours vide, utiliser les IDs valides directement
            if len(recommended_ids) == 0:
                logger.warning(f"Aucun produit populaire valide trouvé. Utilisation des IDs valides disponibles.")
                recommended_ids = get_fallback_valid_products(n_recommendations, exclude_ids=user_products)
        
        # Si on n'a pas assez de recommandations, compléter avec des produits populaires
        if len(recommended_ids) < n_recommendations:
            logger.info(f"Seulement {len(recommended_ids)} recommandations. Complétion avec produits populaires.")
            popular = get_popular_valid_products(n_recommendations * 2)
            # Exclure ceux déjà recommandés et ceux déjà vus
            popular_filtered = [pid for pid in popular if pid not in recommended_ids and pid not in user_products]
            recommended_ids.extend(popular_filtered[:n_recommendations - len(recommended_ids)])
            
            # Si toujours pas assez, utiliser les IDs valides directement
            if len(recommended_ids) < n_recommendations:
                fallback = get_fallback_valid_products(
                    n_recommendations - len(recommended_ids),
                    exclude_ids=list(recommended_ids) + list(user_products)
                )
                recommended_ids.extend(fallback)
        
        # Déterminer le type de recommandation
        rec_type = 'personalized' if len(product_scores) > 0 else 'popular'
        
        logger.info(f"Recommandations générées pour user {user_id}: {recommended_ids} (type: {rec_type})")
        
        return jsonify({
            'user_id': user_id,
            'recommendations': recommended_ids[:n_recommendations],
            'type': rec_type,
            'count': len(recommended_ids[:n_recommendations])
        })
        
    except Exception as e:
        logger.error(f"Erreur dans /recommend: {str(e)}")
        return jsonify({'error': str(e)}), 400


@app.route('/sync-products', methods=['POST'])
def sync_products():
    """
    Synchronise la liste des IDs de produits valides depuis PrestaShop
    Body: {"product_ids": [1, 2, 3, 4, 5, ...]}
    """
    try:
        data = request.json
        if 'product_ids' not in data:
            return jsonify({'error': 'product_ids manquant dans le body'}), 400
        
        global valid_product_ids
        product_ids = data['product_ids']
        
        # Convertir en entiers et créer un set
        valid_product_ids = set(int(pid) for pid in product_ids)
        
        # Sauvegarder dans un fichier pour persistance
        try:
            with open(VALID_PRODUCTS_PATH, 'w') as f:
                json.dump({'product_ids': list(valid_product_ids), 'updated_at': datetime.now().isoformat()}, f)
            logger.info(f"Synchronisé {len(valid_product_ids)} IDs de produits valides depuis PrestaShop")
        except Exception as e:
            logger.error(f"Erreur lors de la sauvegarde des produits valides: {e}")
        
        return jsonify({
            'message': 'Produits synchronisés avec succès',
            'count': len(valid_product_ids),
            'updated_at': datetime.now().isoformat()
        })
        
    except Exception as e:
        logger.error(f"Erreur dans /sync-products: {str(e)}")
        return jsonify({'error': str(e)}), 400


@app.route('/update', methods=['POST'])
def update_model():
    """
    Endpoint pour mettre à jour le modèle à partir du temps passé sur une fiche produit.
    Body attendu (JSON):
    {
        "user_id": 1,
        "product_id": 101,
        "time_spent_seconds": 25
    }

    Le temps est converti en rating implicite, stocké dans un fichier d'interactions,
    puis le modèle KNN est reconstruit à partir de l'ensemble des interactions.
    """
    try:
        data = request.json or {}
        user_id = int(data.get('user_id', 0))
        product_id = int(data.get('product_id', 0))
        time_spent_seconds = int(data.get('time_spent_seconds', 0))

        if user_id <= 0 or product_id <= 0 or time_spent_seconds <= 0:
            return jsonify({
                'error': 'Paramètres invalides. user_id, product_id et time_spent_seconds doivent être positifs.',
                'received': data
            }), 400

        rating = time_to_rating(time_spent_seconds)
        logger.info(f"Réception update: user_id={user_id}, product_id={product_id}, time_spent={time_spent_seconds}s -> rating={rating}")

        # Enregistrer l'interaction
        append_interaction(user_id, product_id, time_spent_seconds, rating)

        # Recalculer le modèle à partir de toutes les interactions
        rebuild_model_from_interactions()

        return jsonify({
            'message': 'Interaction enregistrée et modèle mis à jour',
            'user_id': user_id,
            'product_id': product_id,
            'time_spent_seconds': time_spent_seconds,
            'rating': rating
        })

    except Exception as e:
        logger.error(f"Erreur dans /update: {str(e)}")
        return jsonify({'error': str(e)}), 400


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=False)