Skip to content

Commit 1f5b4cf

Browse files
committed
feat : 여러 메소드 추가
1 parent 63b960a commit 1f5b4cf

4 files changed

Lines changed: 279 additions & 119 deletions

File tree

app/diary/router.py

Lines changed: 215 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from app.statistics.models import EmotionStatistics
88
from app.user.auth import get_current_user
99
from app.diary.models import Diary, RecommendedSong
10-
from app.diary.schemas import DiaryCreateRequest, DiaryUpdateRequest, DiaryResponse, DiaryCountResponse, SongResponse
10+
from app.diary.schemas import DiaryCreateRequest, DiaryUpdateRequest, DiaryResponse, DiaryCountResponse, SongResponse, \
11+
DiaryPreviewResponse
1112
from app.user.models import User
1213
from app.embedding.models import kobert, save_diary_embedding, split_sentences, get_user_preferred_genres, \
1314
get_songs_by_genre, get_song_embeddings, calculate_similarity
@@ -25,112 +26,28 @@
2526
logging.basicConfig(level=logging.INFO)
2627
logger = logging.getLogger(__name__)
2728

28-
@router.post("", response_model=DiaryResponse, status_code=201, summary="일기 작성 & 노래 추천",
29-
description="일기를 작성하면 자동으로 임베딩을 진행하고, 사용자의 선호 장르 내에서 가장 유사한 노래를 추천합니다.")
30-
async def create_diary(
31-
diary_request: DiaryCreateRequest,
32-
current_user: User = Depends(get_current_user),
33-
db: Session = Depends(get_db),
34-
mongodb=Depends(get_mongodb)
35-
):
29+
30+
def get_recently_recommended_song_ids(session: Session, user_id: int, limit: int = 5) -> List[int]:
3631
"""
37-
1. 새로운 일기를 DB에 저장
38-
2. Kiwi를 이용해 문장 분리 후 KoBERT로 임베딩
39-
3. DiaryEmbedding 테이블에 저장
40-
4. 유저의 선호 장르 기반으로 MongoDB에서 노래 리스트 가져오기
41-
5. 가사와 일기 텍스트 임베딩 값 비교 후 가장 유사한 노래 추천
32+
최근 작성한 일기 중에서 추천된 노래 ID 리스트를 반환 (중복 제거)
4233
"""
34+
subquery = (
35+
session.query(Diary.id)
36+
.filter(Diary.user_id == user_id)
37+
.order_by(Diary.created_at.desc())
38+
.limit(limit)
39+
.subquery()
40+
)
4341

44-
with transactional_session(db) as session:
45-
sentences = split_sentences(diary_request.content)
46-
logger.info(f"[일기 문장 분리] - 원본: {diary_request.content}")
47-
for idx, sentence in enumerate(sentences):
48-
logger.info(f" ▶ 문장 {idx + 1}: {sentence}")
49-
50-
embeddings = [kobert.get_embedding(sentence) for sentence in sentences if sentence.strip()]
51-
if not embeddings:
52-
logger.warning("KoBERT 임베딩 결과가 없음")
53-
return {"message": "임베딩할 문장이 없습니다."}
54-
55-
logger.info(f"[KoBERT 임베딩 완료] - {len(embeddings)}개 문장 처리 완료")
56-
57-
# 2) 유저 선호 장르 가져오기
58-
user_id = current_user.id
59-
genre_names = get_user_preferred_genres(session, user_id)
60-
if not genre_names:
61-
logger.warning(f"유저 {user_id}의 선호 장르가 설정되지 않음")
62-
return {"message": "유저의 선호 장르가 설정되지 않았습니다."}
63-
64-
logger.info(f"🎵 [유저 선호 장르] - {genre_names}")
65-
66-
# 3) MongoDB에서 해당 장르의 노래 가져오기
67-
songs = await get_songs_by_genre(mongodb, genre_names)
68-
if not songs:
69-
logger.warning("해당 장르에 노래가 없음")
70-
return {"message": "해당 장르에 노래가 없습니다."}
71-
72-
song_ids = [song["id"] for song in songs]
73-
logger.info(f"🎼 [가져온 노래 개수] - {len(songs)}")
74-
75-
# 4) 노래 가사 임베딩 불러오기 및 유사도 계산
76-
song_embeddings = get_song_embeddings(session, song_ids)
77-
best_match = calculate_similarity(embeddings[0], song_embeddings) # 첫 번째 문장만 비교
78-
79-
if not best_match:
80-
logger.warning("유사한 가사를 찾을 수 없음")
81-
return {"message": "유사한 가사를 찾을 수 없습니다."}
82-
83-
song_id, best_idx, similarity_score = best_match
84-
matching_song = next((song for song in songs if song["id"] == str(song_id)), None)
85-
86-
if matching_song is None:
87-
logger.error(f"추천된 song_id {song_id}가 MongoDB에서 찾을 수 없음")
88-
return {"message": "추천된 노래를 찾을 수 없습니다."}
89-
90-
# best_idx가 가사 범위를 벗어나지 않는지 확인
91-
if best_idx >= len(matching_song["lyrics"]):
92-
logger.error(f"best_idx {best_idx}가 가사 범위를 초과함 (가사 개수: {len(matching_song['lyrics'])})")
93-
return {"message": "유사한 가사를 찾을 수 없습니다."}
94-
95-
start = max(0, best_idx - 1)
96-
end = min(len(matching_song["lyrics"]), best_idx + 2)
97-
98-
context_lyrics = matching_song["lyrics"][start:end]
99-
best_lyric = " ".join(context_lyrics)
100-
101-
# 5) 모든 과정 완료 후 일기 저장 (트랜잭션 보장)
102-
new_diary = Diary(
103-
user_id=current_user.id,
104-
content=diary_request.content
105-
)
106-
session.add(new_diary)
107-
session.commit()
108-
session.refresh(new_diary)
109-
110-
logger.info(f"[📖 일기 저장 완료] - {new_diary.content}")
111-
112-
save_diary_embedding(session, new_diary.id, embeddings)
113-
114-
response_data = {
115-
"id": new_diary.id,
116-
"user_id": new_diary.user_id,
117-
"content": new_diary.content,
118-
"created_at": new_diary.created_at,
119-
"updated_at": new_diary.updated_at,
120-
"recommended_song": {
121-
"song_id": song_id,
122-
"song_name": matching_song.get("song_name", "제목 없음"),
123-
"best_lyric": best_lyric,
124-
"similarity_score": round(float(similarity_score), 4),
125-
"album_image": matching_song.get("album_image", "이미지 없음"),
126-
"artist": matching_song.get("artist_name_basket", ["아티스트 없음"]),
127-
"genre": matching_song.get("genre", "장르 없음")
128-
}
129-
}
130-
131-
logger.info(f" [응답 데이터] - {json.dumps(response_data, ensure_ascii=False, indent=4, default=str)}")
42+
song_ids = (
43+
session.query(RecommendedSong.song_id)
44+
.filter(RecommendedSong.diary_id.in_(subquery))
45+
.distinct()
46+
.all()
47+
)
13248

133-
return response_data
49+
# 결과는 [(song_id1,), (song_id2,), ...] 형태이므로 flatten
50+
return [sid[0] for sid in song_ids]
13451

13552
@router.post("/main", response_model=DiaryResponse, status_code=201,
13653
summary="일기 작성 & Top-3 유사 가사 기반 노래 추천",
@@ -191,9 +108,19 @@ async def create_diary_with_music_recommend_top3(
191108

192109
logger.info(f" ▶ 최종 전체 감정 ID: {emotion_id_full}, 확신도 총합: {confidence_full:.4f}")
193110

194-
# 4) 가장 감정이 강한 문장 선택
195-
best_sentence, best_emotion_id, best_confidence = max(sentence_confidences, key=lambda x: x[2])
196-
logger.info(f"[감정이 가장 강한 문장 선택] {best_sentence} (감정 ID={best_emotion_id}, 확신도={best_confidence:.4f})")
111+
# 4) Top-1 감정과 일치하는 문장 중 가장 확신도 높은 문장 선택
112+
top1_emotion_id = emotion_id_full # 모델 기준 감정 ID
113+
filtered_sentences = [
114+
(sentence, emo_id, conf)
115+
for sentence, emo_id, conf in sentence_confidences
116+
if emo_id == top1_emotion_id
117+
]
118+
119+
if not filtered_sentences:
120+
raise HTTPException(status_code=500, detail="Top 감정에 해당하는 문장이 없습니다.")
121+
122+
best_sentence, best_emotion_id, best_confidence = max(filtered_sentences, key=lambda x: x[2])
123+
logger.info(f"[Top 감정에서 가장 강한 문장 선택] {best_sentence} (감정 ID={best_emotion_id}, 확신도={best_confidence:.4f})")
197124

198125
# 5) best_sentence를 KoBERT 임베딩
199126
combined_embedding = kobert.get_embedding(best_sentence)
@@ -276,12 +203,23 @@ async def create_diary_with_music_recommend_top3(
276203
# 이후 raw_top, top_3, recommended_songs 생성은 기존 코드 그대로 유지
277204
raw_top = heapq.nlargest(10, heap, key=lambda x: (x[0], x[1]))
278205

206+
recent_song_ids = get_recently_recommended_song_ids(user_id=current_user.id, limit=5)
207+
279208
seen_song_ids = set()
280209
top_3 = []
281210
for sim, _, match in raw_top:
282-
if match["song_id"] not in seen_song_ids:
283-
top_3.append((sim, match))
284-
seen_song_ids.add(match["song_id"])
211+
song_id = match["song_id"]
212+
213+
if song_id in seen_song_ids:
214+
continue
215+
216+
if song_id in recent_song_ids:
217+
logger.info(f"최근 추천된 곡 {song_id} 제외")
218+
continue
219+
220+
top_3.append((sim, match))
221+
seen_song_ids.add(song_id)
222+
285223
if len(top_3) >= 3:
286224
break
287225

@@ -370,6 +308,174 @@ async def create_diary_with_music_recommend_top3(
370308
logger.info("추천 결과: %s", json.dumps(response_data, indent=2, ensure_ascii=False, default=str))
371309
return response_data
372310

311+
@router.post("/preview", summary="일기 감정 분석 + 추천 미리보기", response_model=DiaryPreviewResponse)
312+
async def preview_diary_with_music_recommend_top3(
313+
diary_request: DiaryCreateRequest,
314+
current_user: User = Depends(get_current_user),
315+
db: Session = Depends(get_db),
316+
mongodb = Depends(get_mongodb),
317+
redis = Depends(get_redis)
318+
):
319+
sentences = split_sentences(diary_request.content)
320+
if not sentences:
321+
raise HTTPException(status_code=400, detail="분석할 문장이 없습니다.")
322+
323+
sentence_emotions = []
324+
sentence_confidences = []
325+
emotion_vote_counter = {}
326+
327+
for sentence in sentences:
328+
emotion_id, probabilities = predict_emotion(sentence)
329+
confidence = max(probabilities)
330+
331+
topk = torch.topk(torch.tensor(probabilities), k=3)
332+
top3 = [
333+
{"emotion_id": topk.indices[i].item(), "score": round(topk.values[i].item(), 4)}
334+
for i in range(3) if topk.values[i].item() >= 0.01
335+
]
336+
337+
sentence_confidences.append((sentence, emotion_id, confidence))
338+
sentence_emotions.append({
339+
"sentence": sentence,
340+
"predicted_emotion_id": emotion_id,
341+
"confidence": round(confidence, 4),
342+
"top3": top3
343+
})
344+
345+
for i in range(3):
346+
emo_id = topk.indices[i].item()
347+
score = topk.values[i].item()
348+
if score < 0.05:
349+
continue
350+
emotion_vote_counter[emo_id] = emotion_vote_counter.get(emo_id, 0) + score
351+
352+
if not emotion_vote_counter:
353+
raise HTTPException(status_code=500, detail="감정 분석 실패")
354+
355+
top1_emotion_id = max(emotion_vote_counter.items(), key=lambda x: x[1])[0]
356+
confidence_full = emotion_vote_counter[top1_emotion_id]
357+
emotion_id_db = model_index_to_db_emotion_id[top1_emotion_id]
358+
359+
# Top 감정 기준 가장 강한 문장
360+
filtered_sentences = [
361+
(s, eid, c) for (s, eid, c) in sentence_confidences if eid == top1_emotion_id
362+
]
363+
if not filtered_sentences:
364+
raise HTTPException(status_code=500, detail="Top 감정 문장 없음")
365+
best_sentence, best_emotion_id, best_confidence = max(filtered_sentences, key=lambda x: x[2])
366+
367+
combined_embedding = kobert.get_embedding(best_sentence)
368+
369+
genre_names = get_user_preferred_genres(db, current_user.id)
370+
if not genre_names:
371+
raise HTTPException(status_code=400, detail="선호 장르가 설정되지 않았습니다.")
372+
373+
songs = await get_songs_by_genre(mongodb, genre_names)
374+
if not songs:
375+
raise HTTPException(status_code=404, detail="해당 장르에 노래가 없습니다.")
376+
377+
heap = []
378+
counter = 0
379+
song_id_map = {int(song["id"]): song for song in songs}
380+
song_ids = list(song_id_map.keys())
381+
cache_keys = [f"lyrics_emb:{song_id}" for song_id in song_ids]
382+
cached_values = await redis.mget(cache_keys)
383+
384+
combined_np = np.array(combined_embedding)
385+
for song_id, cached in zip(song_ids, cached_values):
386+
try:
387+
if cached:
388+
lyrics_embedding = np.array(json.loads(cached))
389+
else:
390+
result = db.execute(
391+
text("SELECT embedding FROM songLyricsEmbedding WHERE song_id = :song_id"),
392+
{"song_id": song_id}
393+
).fetchone()
394+
if not result:
395+
continue
396+
lyrics_embedding = np.array(json.loads(result[0]))
397+
await redis.set(f"lyrics_emb:{song_id}", json.dumps(lyrics_embedding.tolist()), ex=60*60*24*30)
398+
399+
if len(lyrics_embedding.shape) != 2:
400+
continue
401+
402+
song = song_id_map[song_id]
403+
lyrics = song.get("lyrics", [])
404+
if len(lyrics) < 1 or len(lyrics_embedding) != len(lyrics):
405+
continue
406+
407+
dot = np.dot(lyrics_embedding, combined_np)
408+
norm_block = np.linalg.norm(lyrics_embedding, axis=1)
409+
norm_query = np.linalg.norm(combined_np)
410+
similarities = dot / (norm_block * norm_query + 1e-8)
411+
412+
for idx, similarity in enumerate(similarities):
413+
heapq.heappush(heap, (
414+
similarity,
415+
counter,
416+
{
417+
"song_id": song_id,
418+
"lyric_chunk": [lyrics[idx]],
419+
"similarity": similarity,
420+
"metadata": {
421+
"song_name": song.get("song_name"),
422+
"album_image": song.get("album_image"),
423+
"artist": song.get("artist_name_basket", []),
424+
"genre": song.get("genre")
425+
}
426+
}
427+
))
428+
counter += 1
429+
except Exception as e:
430+
logger.error(f"[preview] 노래 유사도 처리 오류: {e}")
431+
continue
432+
433+
raw_top = heapq.nlargest(10, heap, key=lambda x: (x[0], x[1]))
434+
seen_song_ids = set()
435+
top_3 = []
436+
for sim, _, match in raw_top:
437+
if match["song_id"] not in seen_song_ids:
438+
top_3.append((sim, match))
439+
seen_song_ids.add(match["song_id"])
440+
if len(top_3) >= 3:
441+
break
442+
443+
if not top_3:
444+
raise HTTPException(status_code=404, detail="적합한 노래를 찾을 수 없습니다.")
445+
446+
recommended_songs = [
447+
{
448+
"song_id": match["song_id"],
449+
"song_name": match["metadata"]["song_name"],
450+
"best_lyric": " ".join(match["lyric_chunk"]),
451+
"similarity_score": round(float(sim), 4),
452+
"album_image": match["metadata"]["album_image"],
453+
"artist": match["metadata"]["artist"],
454+
"genre": match["metadata"]["genre"]
455+
}
456+
for sim, match in top_3
457+
]
458+
459+
return {
460+
"id": -1,
461+
"user_id": current_user.id,
462+
"content": diary_request.content,
463+
"emotiontype_id": emotion_id_db,
464+
"confidence": confidence_full,
465+
"created_at": datetime.utcnow(),
466+
"updated_at": datetime.utcnow(),
467+
"recommended_songs": recommended_songs,
468+
"top_emotions": [
469+
{"emotion_id": emo_id, "score": round(score, 4)}
470+
for emo_id, score in sorted(emotion_vote_counter.items(), key=lambda x: -x[1])[:3]
471+
],
472+
"best_sentence": {
473+
"sentence": best_sentence,
474+
"predicted_emotion_id": best_emotion_id,
475+
"confidence": round(best_confidence, 4)
476+
},
477+
"sentence_emotions": sentence_emotions
478+
}
373479

374480
@router.get("/{diary_id}", response_model=DiaryResponse,
375481
summary="일기 조회",

0 commit comments

Comments
 (0)