Skip to content

Commit 0695951

Browse files
committed
feat: 일기 작성시 동일한 가사 추천 방지 코드 추가(20회)
1 parent ae75c96 commit 0695951

1 file changed

Lines changed: 261 additions & 11 deletions

File tree

app/diary/router.py

Lines changed: 261 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from app.embedding.models import kobert, save_diary_embedding, split_sentences, get_user_preferred_genres, \
1414
get_songs_by_genre, get_song_embeddings, calculate_similarity
1515
from app.transaction import transactional_session
16-
from typing import List
16+
from typing import List, Set
1717
import logging
1818
import json
1919
import torch
@@ -27,9 +27,9 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30-
def get_recently_recommended_song_ids(session, user_id: int, limit: int = 5) -> List[int]:
30+
def get_recently_recommended_lyrics(session, user_id: int, limit: int = 20) -> Set[str]:
3131
"""
32-
최근 작성한 일기 중에서 추천된 노래 ID 리스트를 반환 (중복 제거)
32+
최근 일기에서 추천된 가사 블럭(best_lyric)들을 반환
3333
"""
3434
subquery = (
3535
session.query(Diary.id)
@@ -39,15 +39,14 @@ def get_recently_recommended_song_ids(session, user_id: int, limit: int = 5) ->
3939
.subquery()
4040
)
4141

42-
song_ids = (
43-
session.query(RecommendedSong.song_id)
42+
lyrics = (
43+
session.query(RecommendedSong.best_lyric)
4444
.filter(RecommendedSong.diary_id.in_(subquery))
4545
.distinct()
4646
.all()
4747
)
4848

49-
# 결과는 [(song_id1,), (song_id2,), ...] 형태이므로 flatten
50-
return [sid[0] for sid in song_ids]
49+
return set(lyric[0] for lyric in lyrics)
5150

5251
@router.post("/main", response_model=DiaryResponse, status_code=201,
5352
summary="일기 작성 & Top-3 유사 가사 기반 노래 추천",
@@ -203,22 +202,25 @@ async def create_diary_with_music_recommend_top3(
203202
# 이후 raw_top, top_3, recommended_songs 생성은 기존 코드 그대로 유지
204203
raw_top = heapq.nlargest(10, heap, key=lambda x: (x[0], x[1]))
205204

206-
recent_song_ids = get_recently_recommended_song_ids(session, user_id=current_user.id, limit=5)
205+
recent_lyrics = get_recently_recommended_lyrics(session, user_id=current_user.id)
207206

208207
seen_song_ids = set()
208+
seen_lyrics = set()
209209
top_3 = []
210+
210211
for sim, _, match in raw_top:
211212
song_id = match["song_id"]
213+
lyric_chunk = " ".join(match["lyric_chunk"]).strip()
212214

213215
if song_id in seen_song_ids:
214216
continue
215-
216-
if song_id in recent_song_ids:
217-
logger.info(f"최근 추천된 곡 {song_id} 제외")
217+
if lyric_chunk in recent_lyrics:
218+
logger.info(f"최근 추천된 동일 가사 블럭 제외됨: {lyric_chunk}")
218219
continue
219220

220221
top_3.append((sim, match))
221222
seen_song_ids.add(song_id)
223+
seen_lyrics.add(lyric_chunk)
222224

223225
if len(top_3) >= 3:
224226
break
@@ -477,6 +479,254 @@ async def preview_diary_with_music_recommend_top3(
477479
"sentence_emotions": sentence_emotions
478480
}
479481

482+
@router.post("/emotion-based", response_model=DiaryResponse, status_code=201,
             summary="일기 작성 & 감정 기반 Top-3 노래 추천",
             description="일기를 작성하면 KoBERT로 감정을 분석하고, 해당 감정에 맞는 장르에서 가사 유사도를 기준으로 노래를 추천합니다.")
async def create_diary_with_emotion_based_recommendation(
    diary_request: DiaryCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
    mongodb = Depends(get_mongodb),
    redis = Depends(get_redis)
):
    """Create a diary entry and recommend Top-3 songs via emotion-matched genres.

    Pipeline:
      1. Split the diary into sentences and run KoBERT emotion prediction per
         sentence, accumulating a score-weighted vote over each sentence's
         top-3 emotions (scores below 0.05 are discarded as noise).
      2. Pick the overall winning emotion, then embed the single sentence that
         expressed that emotion most strongly.
      3. Fetch candidate songs from the genres mapped to that emotion, load
         per-line lyric embeddings (Redis cache first, DB fallback), and rank
         lyric lines by cosine similarity against the sentence embedding.
      4. Keep up to 3 matches, skipping duplicate songs, lyric blocks recently
         recommended to this user, and lyric blocks already chosen within this
         same response.
      5. Persist the diary, its embedding, and the recommended songs; return
         the response payload.

    Raises:
        HTTPException: 400 (no analyzable sentences / no genre mapping),
            404 (no candidate songs / no suitable match),
            500 (no valid sentence for the top emotion / insufficient
            embedding data).
    """
    with transactional_session(db) as session:
        sentences = split_sentences(diary_request.content)
        if not sentences:
            raise HTTPException(status_code=400, detail="분석할 문장이 없습니다.")

        sentence_confidences = []
        emotion_vote_counter = {}

        logger.info("[문장별 감정 분석 시작]")
        for idx, sentence in enumerate(sentences):
            if not sentence.strip():
                continue

            logger.info(f" ▶ 문장 {idx + 1}: {sentence}")

            emotion_id, probabilities = predict_emotion(sentence)
            confidence = max(probabilities)

            logger.info(f" ▶ 문장 {idx + 1} 예측 감정 ID: {emotion_id}, 확신도: {confidence:.4f}")
            sentence_confidences.append((sentence, emotion_id, confidence))

            # Weighted voting: each sentence contributes its top-3 emotion
            # probabilities (above the 0.05 noise floor) to the global tally.
            probs_tensor = torch.tensor(probabilities)
            topk = torch.topk(probs_tensor, k=3)

            for i in range(3):
                emo_id = topk.indices[i].item()
                score = topk.values[i].item()

                if score < 0.05:
                    continue

                if emo_id not in emotion_vote_counter:
                    emotion_vote_counter[emo_id] = 0.0
                emotion_vote_counter[emo_id] += score

        # FIX: guard against an empty tally (e.g. every sentence was
        # whitespace-only and skipped above). Without this, max() on an empty
        # dict raises ValueError and surfaces as an unhandled 500.
        if not emotion_vote_counter:
            raise HTTPException(status_code=400, detail="분석할 문장이 없습니다.")

        emotion_id_full = max(emotion_vote_counter.items(), key=lambda x: x[1])[0]
        confidence_full = emotion_vote_counter[emotion_id_full]
        emotion_id_db = model_index_to_db_emotion_id[emotion_id_full]

        logger.info("[문장별 감정 통계 기반 전체 감정 분석 결과]")
        for emo_id, score in sorted(emotion_vote_counter.items(), key=lambda x: -x[1]):
            logger.info(f" ▶ 감정 ID={emo_id}, 확신도 총합={score:.4f}")

        logger.info(f" ▶ 최종 전체 감정 ID: {emotion_id_full}, 확신도 총합: {confidence_full:.4f}")

        # Among sentences whose per-sentence prediction agrees with the global
        # winner, embed the one with the highest confidence.
        top1_emotion_id = emotion_id_full
        filtered_sentences = [
            (sentence, emo_id, conf)
            for sentence, emo_id, conf in sentence_confidences
            if emo_id == top1_emotion_id
        ]

        if not filtered_sentences:
            raise HTTPException(status_code=500, detail="Top 감정에 해당하는 문장이 없습니다.")

        best_sentence, best_emotion_id, best_confidence = max(filtered_sentences, key=lambda x: x[2])
        logger.info(f"[Top 감정에서 가장 강한 문장 선택] {best_sentence} (감정 ID={best_emotion_id}, 확신도={best_confidence:.4f})")

        combined_embedding = kobert.get_embedding(best_sentence)

        # Static emotion-index -> genre mapping.
        # NOTE(review): keys assumed to follow the KoBERT model's label order
        # (same index space as model_index_to_db_emotion_id) — confirm.
        emotion_to_genres = {
            0: ["댄스", "랩/힙합"],
            1: ["R&B/Soul", "댄스"],
            2: ["인디음악", "R&B/Soul"],
            3: ["R&B/Soul", "인디음악"],
            4: ["록/메탈", "인디음악"],
            5: ["발라드", "록/메탈"],
            6: ["발라드", "R&B/Soul"],
            7: ["랩/힙합", "록/메탈"]
        }

        genre_names = emotion_to_genres.get(emotion_id_full)
        if not genre_names:
            raise HTTPException(status_code=400, detail="감정에 대응되는 장르가 없습니다.")

        songs = await get_songs_by_genre(mongodb, genre_names)
        if not songs:
            raise HTTPException(status_code=404, detail="해당 감정의 장르에 노래가 없습니다.")

        song_id_map = {int(song["id"]): song for song in songs}
        song_ids = list(song_id_map.keys())
        cache_keys = [f"lyrics_emb:{song_id}" for song_id in song_ids]
        cached_values = await redis.mget(cache_keys)

        # L2-normalize the diary embedding once; lyric embeddings are normed
        # the same way below, so a plain dot product equals cosine similarity.
        combined_np = np.array(combined_embedding)
        combined_np = combined_np / (np.linalg.norm(combined_np) + 1e-8)

        all_embeddings = []
        meta_infos = []

        for song_id, cached in zip(song_ids, cached_values):
            try:
                if cached:
                    lyrics_embedding = np.array(json.loads(cached))
                else:
                    result = session.execute(
                        text("SELECT embedding FROM songLyricsEmbedding WHERE song_id = :song_id"),
                        {"song_id": song_id}
                    ).fetchone()
                    if not result:
                        continue
                    lyrics_embedding = np.array(json.loads(result[0]))
                    # Warm the cache for 30 days on a DB miss.
                    await redis.set(f"lyrics_emb:{song_id}", json.dumps(lyrics_embedding.tolist()), ex=60*60*24*30)

                # Expect a 2-D matrix: one embedding row per lyric line.
                if len(lyrics_embedding.shape) != 2:
                    continue

                song = song_id_map[song_id]
                lyrics = song.get("lyrics", [])
                # Skip songs whose lyric count and embedding rows disagree —
                # row i must correspond to lyric line i for meta_infos below.
                if len(lyrics) < 1 or len(lyrics_embedding) != len(lyrics):
                    continue

                normed = lyrics_embedding / (np.linalg.norm(lyrics_embedding, axis=1, keepdims=True) + 1e-8)
                all_embeddings.append(normed)

                for idx in range(len(normed)):
                    meta_infos.append({
                        "song_id": song_id,
                        "lyric": lyrics[idx],
                        "metadata": {
                            "song_name": song.get("song_name"),
                            "album_image": song.get("album_image"),
                            "artist": song.get("artist_name_basket", []),
                            "genre": song.get("genre")
                        }
                    })
            except Exception as e:
                # Best-effort per song: a malformed cache entry or DB row must
                # not abort the whole recommendation.
                logger.error(f"노래 처리 중 오류: {e}")
                continue

        if not all_embeddings:
            raise HTTPException(status_code=500, detail="유사도 계산을 위한 데이터가 부족합니다.")

        E = np.vstack(all_embeddings)
        sims = np.dot(E, combined_np)

        heap = []
        for i, similarity in enumerate(sims):
            heapq.heappush(heap, (
                similarity,
                i,  # tie-breaker so the dict payloads are never compared
                {
                    "song_id": meta_infos[i]["song_id"],
                    "lyric_chunk": [meta_infos[i]["lyric"]],
                    "similarity": float(similarity),
                    "metadata": meta_infos[i]["metadata"]
                }
            ))

        # Take a candidate pool larger than 3 so dedupe below has slack.
        top_3_raw = heapq.nlargest(10, heap, key=lambda x: (x[0], x[1]))

        recent_lyrics = get_recently_recommended_lyrics(session, user_id=current_user.id)

        seen_song_ids = set()
        seen_lyrics = set()
        top_3 = []

        for sim, _, match in top_3_raw:
            song_id = match["song_id"]
            lyric_chunk = " ".join(match["lyric_chunk"]).strip()

            if song_id in seen_song_ids:
                continue
            # FIX: also consult seen_lyrics — it was populated but never
            # checked, so two different songs sharing an identical lyric block
            # could both appear in a single response.
            if lyric_chunk in recent_lyrics or lyric_chunk in seen_lyrics:
                logger.info(f"최근 추천된 동일 가사 블럭 제외됨: {lyric_chunk}")
                continue

            top_3.append((sim, match))
            seen_song_ids.add(song_id)
            seen_lyrics.add(lyric_chunk)

            if len(top_3) >= 3:
                break

        if not top_3:
            raise HTTPException(status_code=404, detail="적합한 노래를 찾을 수 없습니다.")

        new_diary = Diary(
            user_id=current_user.id,
            content=diary_request.content,
            emotiontype_id=emotion_id_db,
            confidence=confidence_full,
            best_sentence=best_sentence,
            created_at=datetime.utcnow()
        )
        session.add(new_diary)
        session.commit()
        session.refresh(new_diary)

        save_diary_embedding(session, new_diary.id, combined_embedding)

        recommended_songs = [
            {
                "song_id": match["song_id"],
                "song_name": match["metadata"]["song_name"],
                "best_lyric": " ".join(match["lyric_chunk"]),
                "similarity_score": round(float(sim), 4),
                "album_image": match["metadata"]["album_image"],
                "artist": match["metadata"]["artist"],
                "genre": match["metadata"]["genre"]
            }
            for sim, match in top_3
        ]

        for song_data in recommended_songs:
            new_song = RecommendedSong(
                diary_id=new_diary.id,
                song_id=song_data["song_id"],
                song_name=song_data["song_name"],
                artist=song_data["artist"],
                genre=song_data["genre"],
                album_image=song_data["album_image"],
                best_lyric=song_data["best_lyric"],
                similarity_score=song_data["similarity_score"]
            )
            session.add(new_song)

        session.commit()

        response_data = {
            "id": new_diary.id,
            "user_id": new_diary.user_id,
            "content": new_diary.content,
            "emotiontype_id": emotion_id_db,
            "confidence": confidence_full,
            "created_at": new_diary.created_at,
            "updated_at": new_diary.updated_at,
            "recommended_songs": recommended_songs,
            "top_emotions": [
                {"emotion_id": eid, "score": round(score, 4)}
                for eid, score in sorted(emotion_vote_counter.items(), key=lambda x: -x[1])[:3]
            ]
        }

        logger.info("추천 결과: %s", json.dumps(response_data, indent=2, ensure_ascii=False, default=str))
        return response_data
728+
729+
480730
@router.get("/{diary_id}", response_model=DiaryResponse,
481731
summary="일기 조회",
482732
description="특정 일기의 상세 정보를 조회합니다.")

0 commit comments

Comments
 (0)