Spaces:
Running
Running
File size: 23,377 Bytes
2ca6d4c 326f795 d2ea144 b72a629 26d2a32 b72a629 7eb4c01 a58391a b72a629 afa43df 26d2a32 326f795 26d2a32 b72a629 326f795 b72a629 ef1a1de b72a629 d2ea144 b72a629 326f795 b72a629 326f795 b72a629 a58391a 326f795 a58391a c793402 a58391a b72a629 c793402 ef1a1de afa43df ef1a1de b72a629 afa43df b72a629 c793402 ef1a1de 326f795 b72a629 326f795 b72a629 326f795 b72a629 326f795 b72a629 afa43df c793402 b72a629 afa43df 326f795 b72a629 afa43df 5c789bb b72a629 52eae6a b72a629 d2ea144 326f795 d2ea144 326f795 d2ea144 aeb8fbf d2ea144 326f795 d2ea144 52eae6a d2ea144 326f795 d2ea144 52eae6a 326f795 52eae6a d2ea144 52eae6a 326f795 52eae6a 326f795 52eae6a 326f795 52eae6a 326f795 52eae6a 326f795 52eae6a d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 b72a629 7eb4c01 b72a629 7eb4c01 b72a629 326f795 7eb4c01 b72a629 7eb4c01 326f795 7eb4c01 b72a629 7eb4c01 326f795 7eb4c01 326f795 7eb4c01 b72a629 f7ddd83 b72a629 c793402 b72a629 e4db74d b72a629 afa43df b72a629 326f795 b72a629 326f795 5285b6f a58391a 326f795 a58391a afa43df 326f795 5285b6f 326f795 b72a629 326f795 5285b6f a58391a 326f795 a58391a afa43df a58391a b72a629 326f795 5777af7 2000df9 326f795 5777af7 326f795 b72a629 5777af7 b72a629 326f795 b72a629 326f795 5777af7 326f795 5777af7 326f795 5777af7 326f795 5777af7 326f795 5777af7 326f795 5777af7 3ba5ce5 5777af7 b72a629 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 326f795 d2ea144 7eb4c01 326f795 7eb4c01 326f795 7eb4c01 326f795 7eb4c01 b72a629 4dfd818 |
|
# --- eventlet monkey patch (required for Gunicorn + SocketIO) ---
# This must run before any other module is imported: modules imported earlier
# (Flask, requests, the Together SDK, ...) may otherwise bind the unpatched,
# blocking standard-library primitives.
import eventlet
eventlet.monkey_patch()

import os
#os.environ["PYDANTIC_V1_STYLE"] = "1"
#os.environ["PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS"] = "1"
# --------------------------------------------------------------------------
from flask import Flask, render_template, jsonify, request, Response
from flask_socketio import SocketIO, emit
import uuid
import threading
import sqlite3
import gc
import time
import re
import traceback
import requests  # needed for API calls
from typing import Optional, Tuple, Any, Dict, List
# --- Together AI SDK ---
from together import Together

# --- Flask & SocketIO setup ---
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

import logging
# Logger setup: INFO level with a timestamp / logger name / level / message format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# --- Project modules ---
# The v02 embedding module is imported under the name `reg_embedding_system`,
# which is the name the rest of this file uses.
import reg_embedding_system_v02 as reg_embedding_system
import leximind_prompts
# --- Global state ---
# Number of currently connected Socket.IO clients (updated by the connect/disconnect handlers).
connected_clients = 0
# Passed as `k` (number of documents to retrieve) to the search functions.
search_document_number = 30
# Set by handle_search_query: True when the request selected specific regulations.
Filtered_search = False
# Fallback filter used by search_DB_from_multiple_regions when no custom filter is given.
filters = {"regulation": []} # default filter key (renamed for the v02 schema)
# --- Paths ---
current_dir = os.path.dirname(os.path.abspath(__file__))
# Passed to RegAI as its third argument.
ResultFile_FolderAddress = os.path.join(current_dir, 'result.txt')
# --- RAG data paths ---
# NOTE: on Hugging Face Spaces, verify that the data actually lives under /app/data.
# Keys are the region labels used by the frontend ("국내", "북미", "유럽").
region_paths = {
    "국내": "/app/data/KMVSS_RAG",
    "북미": "/app/data/FMVSS_RAG",
    "유럽": "/app/data/EUR_RAG"
}
# --- Prompts ---
lexi_prompts = leximind_prompts.PromptLibrary()
# Tracks in-flight search requests: session_id -> True. An entry is removed
# when the search finishes or when the client cancels it.
active_sessions = {}
# --- RAG objects ---
# region label -> {"bm25_retriever", "vectorstore", "sqlite_conn"}; filled by load_rag_objects.
region_rag_objects = {}
# --- Together AI configuration ---
# The API key comes from the environment. A missing key only produces a warning
# (rather than raising) so the app can still start, e.g. for local testing.
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    logger.warning("TOGETHER_API_KEY가 설정되지 않았습니다.")

# `client` is None when the SDK client cannot be constructed; the AI helper
# functions check for that before calling the API.
try:
    client = Together(api_key=TOGETHER_API_KEY)
except Exception as e:
    logger.warning(f"Together Client 초기화 실패 (API 키 확인 필요): {e}")
    client = None

# Last RAG loading status message; re-sent to clients when they connect.
rag_connection_status_info = ""
# --- RAG loading ---
def load_rag_objects():
    """Load the RAG stores for every configured region into `region_rag_objects`.

    For each entry of `region_paths` this loads the BM25 retriever, the vector
    store and the SQLite metadata DB, broadcasts progress over Socket.IO
    ('message' events) and records the latest status in
    `rag_connection_status_info`. A region whose path is missing or whose
    loading fails is skipped; the remaining regions are still loaded.
    """
    global rag_connection_status_info
    logger.info(">>> [RAG_LOADER] RAG 로딩 스레드 시작 <<<")
    for region, path in region_paths.items():
        if not os.path.exists(path):
            msg = f"[{region}] 경로 없음: {path}"
            socketio.emit('message', {'message': msg})
            logger.warning(msg)
            continue
        try:
            socketio.emit('message', {'message': f"[{region}] RAG 로딩 중..."})
            rag_connection_status_info = f"[{region}] RAG 로딩 중..."
            # v02 loader returns a BM25 retriever (previously an ensemble retriever).
            bm25_retriever, vectorstore, sqlite_conn = reg_embedding_system.load_embedding_from_faiss(path)
            # Replace the loader's connection with one opened with
            # check_same_thread=False — presumably so request handlers running
            # on other threads can use it (TODO confirm).
            sqlite_conn.close()
            db_path = os.path.join(path, "metadata_mapping.db")
            new_conn = sqlite3.connect(db_path, check_same_thread=False)
            region_rag_objects[region] = {
                "bm25_retriever": bm25_retriever,
                "vectorstore": vectorstore,
                "sqlite_conn": new_conn
            }
            socketio.emit('message', {'message': f"[{region}] 로딩 완료"})
            logger.info(f"[{region}] RAG 로딩 완료")
            rag_connection_status_info = f"[{region}] RAG 로딩 완료"
        except Exception as e:
            error_msg = f"[{region}] 로딩 실패: {str(e)}"
            # logger.exception records the message at ERROR level together
            # with the traceback, through the configured logging handlers.
            logger.exception(error_msg)
            socketio.emit('message', {'message': error_msg})
    socketio.emit('message', {'message': "Ready to Search"})
    logger.info("Ready to Search")
    rag_connection_status_info = "Ready to Search"
# --- Web ---
@app.route('/')
def index():
    """Serve the chat UI (template 'chat_v03.html')."""
    page = render_template('chat_v03.html')
    return page
# Default value for the search-mode flag. handle_search_query overwrites it with
# the request's 'searchEachMode' field; when True, each selected regulation is
# searched and answered separately.
Search_each_all_mode = True
@socketio.on('search_query')
def handle_search_query(data):
    """Handle a 'search_query' Socket.IO event.

    Reads 'query', 'regions', 'selectedRegulations' and 'searchEachMode' from
    `data`, translates the query, then runs one of three flows:

    * regulations selected and searchEachMode true: search and answer each
      regulation separately ('regulation_result' per regulation) and append
      every non-empty answer to merged_ai_messages.txt;
    * regulations selected and searchEachMode false: one search with the
      combined filters ('search_result');
    * no regulations selected: unfiltered search ('search_result').

    Progress is reported through 'search_status' events, completion through
    'search_complete', failures through 'search_error'. The request is
    registered in `active_sessions` under a fresh session id (announced with
    'search_started'); if that entry disappears (see the 'cancel_search'
    handler) the remaining work is skipped.
    """
    global Filtered_search
    global filters
    global Search_each_all_mode
    # Function-scope import; only used to timestamp entries of the merged result file.
    from datetime import datetime

    session_id = str(uuid.uuid4())
    active_sessions[session_id] = True
    emit('search_started', {'session_id': session_id})
    try:
        Search_each_all_mode = data.get('searchEachMode', True)
        query = data.get('query', '')
        regions = data.get('regions', [])
        selected_regulations = data.get('selectedRegulations', [])
        emit('search_status', {'status': 'processing', 'message': '검색 요청을 처리하는 중입니다...'})
        # Reset the global fallback filter to the empty v02 filter structure.
        filters = {
            "regulation": [],  # formerly regulation_part
            "section": [],     # formerly regulation_section
            "chapter": [],     # formerly chapter_section
            "standard": []     # formerly jo
        }
        emit('search_status', {'status': 'translating', 'message': '질문에 대해 생각 중입니다...'})
        if session_id not in active_sessions:
            return
        Translated_query = Gemma3_AI_Translate(query)
        emit('search_status', {'status': 'translated', 'message': f'번역 완료: {Translated_query}'})

        if selected_regulations:
            Filtered_search = True
            cont_selected_num = 0
            output_path = os.path.join(current_dir, "merged_ai_messages.txt")
            # Start each filtered search with a fresh merged-result file.
            if os.path.exists(output_path):
                os.remove(output_path)
            # Group the selection by regulation type, then build the filters.
            grouped_regulations = group_regulations_by_type(selected_regulations)
            emit('search_status', {'status': 'searching', 'message': f'선택된 {len(selected_regulations)}개 법규를 타입별로 통합하여 검색 중...'})
            combined_filters = create_combined_filters(grouped_regulations)
            combined_cleaned_filter = {k: v for k, v in combined_filters.items() if v}

            if Search_each_all_mode:
                # Per-regulation mode. Counts and progress are based on the
                # number of regulations, not on the number of filter types.
                total_search_num = sum(len(v) for v in combined_cleaned_filter.values())
                emit('search_status', {'status': 'searching', 'message': f'선택된 {total_search_num}개 법규를 각각 검색 중...'})
                i = 0
                for RegType, RegNames in combined_cleaned_filter.items():
                    for RegName in RegNames:
                        i += 1
                        if session_id not in active_sessions:
                            emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
                            return
                        emit('search_status', {
                            'status': 'searching_regulation',
                            'message': f'법규 {i}/{total_search_num}: {RegName} 검색 중...',
                            'progress': (i / total_search_num) * 100
                        })
                        # Filter restricted to this single regulation.
                        current_filters = create_filter_by_type(RegType, RegName)
                        Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, current_filters)
                        if not Rag_Results:
                            # Nothing retrieved for this regulation: no result event is sent.
                            continue
                        if session_id not in active_sessions:
                            return
                        emit('search_status', {
                            'status': 'ai_processing',
                            'message': f'AI가 {RegName}에 대한 답변을 생성 중...'
                        })
                        AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
                        if session_id not in active_sessions:
                            return
                        emit('regulation_result', {
                            'regulation_title': f"[{RegName}]",
                            'regulation_index': i,
                            'total_regulations': total_search_num,
                            'result': AImessage
                        })
                        if isinstance(AImessage, str) and AImessage.strip():
                            cont_selected_num += 1
                            stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                            with open(output_path, "a", encoding="utf-8") as f:
                                f.write(f"\n--- [{stamp}] message #{cont_selected_num} --- Regulation Type: {RegType} --- Regulation Name : {RegName} ---\n {AImessage}")
                emit('search_complete', {'status': 'completed', 'message': '모든 법규 검색이 완료되었습니다.'})
            else:
                # Combined mode: a single search using all selected regulations.
                Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, combined_filters)
                if session_id in active_sessions:
                    emit('search_status', {'status': 'ai_processing', 'message': 'AI가 통합 답변을 생성 중...'})
                    AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
                    if session_id in active_sessions:
                        emit('search_result', {'result': AImessage})
                        emit('search_complete', {'status': 'completed', 'message': '통합 검색이 완료되었습니다.'})
        else:
            Filtered_search = False
            emit('search_status', {'status': 'searching_all', 'message': '전체 법규에서 검색 중...'})
            # Unfiltered search over all regulations.
            Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, None)
            if session_id in active_sessions:
                emit('search_status', {'status': 'ai_processing', 'message': 'AI가 답변을 생성 중...'})
                AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
                if session_id in active_sessions:
                    emit('search_result', {'result': AImessage})
                    emit('search_complete', {'status': 'completed', 'message': '검색이 완료되었습니다.'})
    except Exception as e:
        # Top-level boundary of the handler: log with traceback and tell the client.
        logger.exception(f"검색 오류: {e}")
        emit('search_error', {'error': str(e), 'message': '검색 중 오류가 발생했습니다.'})
    finally:
        # The entry may already have been removed by a cancel request.
        active_sessions.pop(session_id, None)
@socketio.on('cancel_search')
def handle_cancel_search(data):
    """Cancel the running search identified by data['session_id'].

    Removes the session from `active_sessions` and emits 'search_cancelled'.
    Unknown or missing session ids are ignored.
    """
    session_id = data.get('session_id')
    if not session_id or session_id not in active_sessions:
        return
    del active_sessions[session_id]
    emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
# --- Regulation list ---
@app.route('/get_reg_list', methods=['POST'])
def get_reg_list():
    """Return the distinct regulation metadata values of the selected regions.

    Expects a JSON body with an optional 'regions' list; when it is missing or
    empty, all three regions are used. Regions that are not loaded are skipped.

    Returns:
        JSON with the keys reg_list_part / reg_list_section / reg_list_chapter /
        reg_list_jo, each a newline-joined, de-duplicated, naturally sorted
        list of values. The response keys keep their old names while the
        values come from the v02 columns regulation / section / chapter /
        standard.
    """
    # get_json() yields None for a missing or invalid body; treat that as "no selection".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    selected_regions = data.get('regions', [])
    if not selected_regions:
        selected_regions = ["국내", "북미", "유럽"]

    # response key -> v02 metadata column
    column_by_response_key = {
        "reg_list_part": "regulation",   # formerly regulation_part
        "reg_list_section": "section",   # formerly regulation_section
        "reg_list_chapter": "chapter",   # formerly chapter_section
        "reg_list_jo": "standard",       # formerly jo
    }
    collected = {key: set() for key in column_by_response_key}

    for region in selected_regions:
        rag = region_rag_objects.get(region)
        if not rag:
            continue
        try:
            sqlite_conn = rag["sqlite_conn"]
            region_values = {}
            for key, column in column_by_response_key.items():
                values = get_unique_metadata_values(sqlite_conn, column)
                # The helper returns one newline-joined string; split it so
                # values can be de-duplicated and sorted across regions.
                if isinstance(values, str):
                    values = values.split("\n")
                region_values[key] = [str(v) for v in values if str(v).strip()]
        except Exception as e:
            logger.exception(f"[{region}] DB 연결 오류: {e}")
            continue
        # Merge only after every column of the region was read successfully.
        for key, values in region_values.items():
            collected[key].update(values)

    payload = {
        key: "\n".join(sorted(values, key=reg_embedding_system.natural_sort_key))
        for key, values in collected.items()
    }
    return jsonify(**payload)
# --- SocketIO ---
@socketio.on('connect')
def handle_connect():
    """Count the new client, log its address and re-send the RAG loading status.

    The address is taken from the first X-Forwarded-For entry, else X-Real-IP,
    else CF-Connecting-IP, else the socket's remote address.
    """
    global connected_clients
    global rag_connection_status_info
    connected_clients += 1

    headers = request.headers
    forwarded_for = headers.get('X-Forwarded-For')
    if forwarded_for:
        client_ip = forwarded_for.split(',')[0].strip()
    else:
        client_ip = (
            headers.get('X-Real-IP')
            or headers.get('CF-Connecting-IP')
            or request.remote_addr
        )

    logger.info(f"클라이언트 연결 | IP: {client_ip} | 현재 접속자: {connected_clients}명")
    socketio.emit('message', {'message': rag_connection_status_info})
@socketio.on('disconnect')
def handle_disconnect():
    """Decrement the connected-client counter and log the disconnect."""
    global connected_clients
    connected_clients -= 1
    # Worded as a disconnect so it cannot be mistaken for the connect log line.
    logger.info(f"클라이언트 연결 해제 | 현재 접속자: {connected_clients}명")
def cleanup_connections():
    """Close the SQLite connection of every loaded region (best effort).

    A failure to close one region is logged and does not prevent the remaining
    regions from being closed.
    """
    # Iterate over a snapshot: the loader thread may still be adding regions.
    for region, rag in list(region_rag_objects.items()):
        try:
            rag["sqlite_conn"].close()
            logger.info(f"[{region}] DB 연결 종료")
        except Exception as e:
            logger.warning(f"[{region}] DB 연결 종료 실패: {e}")
# --- Together AI analysis ---
def Gemma3_AI_analysis(query_txt, content_txt):
    """Ask the Together AI chat model to answer `query_txt` from retrieved content.

    `content_txt` may be a list of documents (their `page_content` values are
    joined with newlines) or any other object (converted with `str`).
    Returns the model's answer, or an explanatory message when the client is
    unavailable or the API call fails.
    """
    if isinstance(content_txt, list):
        content_txt = "\n".join(doc.page_content for doc in content_txt)
    else:
        content_txt = str(content_txt)
    query_txt = str(query_txt)
    prompt = lexi_prompts.use_prompt(lexi_prompts.AI_system_prompt, query_txt=query_txt, content_txt=content_txt)

    if not client:
        return "AI Client가 초기화되지 않았습니다."

    try:
        completion = client.chat.completions.create(
            model="moonshotai/Kimi-K2-Instruct-0905",
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content
    except Exception as e:
        logger.info(f"Together AI 분석 API 호출 실패: {e}")
        traceback.print_exc()
        return f"AI 분석 중 오류가 발생했습니다: {e}"
# --- Together AI translation ---
def Gemma3_AI_Translate(query_txt):
    """Translate/rewrite the user query with the Together AI chat model.

    Returns the model output. Falls back to the original query (as a string)
    when the client is unavailable, the API call fails, or the model returns
    no usable text, so that the caller always gets a searchable query.
    """
    query_txt = str(query_txt)
    prompt = lexi_prompts.use_prompt(lexi_prompts.query_translator, query_txt=query_txt)
    if not client:
        return query_txt
    try:
        response = client.chat.completions.create(
            model="moonshotai/Kimi-K2-Instruct-0905",
            messages=[{"role": "user", "content": prompt}],
        )
        translated = response.choices[0].message.content
    except Exception as e:
        logger.exception(f"Together AI 번역 API 호출 실패: {e}")
        return query_txt
    # A completion without text must not reach the search as None or "".
    if not isinstance(translated, str) or not translated.strip():
        logger.warning("Together AI 번역 결과가 비어 있어 원문 질의를 사용합니다.")
        return query_txt
    return translated
# --- Search ---
def search_DB_from_multiple_regions(query, selected_regions, region_rag_objects, custom_filters=None):
    """Search every selected region and return the concatenated results.

    Args:
        query: query text passed to the retrievers.
        selected_regions: region labels to search; all loaded regions when empty.
        region_rag_objects: mapping region -> {"bm25_retriever", "vectorstore", "sqlite_conn"}.
        custom_filters: metadata filter dict; the module-level `filters` is
            used when this is None.

    A metadata-filtered search is used when any filter list is non-empty,
    otherwise the unfiltered "smart" search. Regions that are not loaded or
    have no BM25 retriever are skipped.
    """
    if not selected_regions:
        selected_regions = list(region_rag_objects.keys())
    print(f"Translated Query : {query}")

    search_filters = filters if custom_filters is None else custom_filters
    has_filters = any(search_filters.values())
    print(f"사용된 검색 필터: {search_filters}")

    combined_results = []
    for region in selected_regions:
        rag = region_rag_objects.get(region)
        if not rag:
            continue
        bm25_retriever = rag["bm25_retriever"]
        vectorstore = rag["vectorstore"]
        sqlite_conn = rag["sqlite_conn"]
        if not bm25_retriever:
            continue

        if has_filters:
            results = reg_embedding_system.search_with_metadata_filter(
                bm25_retriever=bm25_retriever,
                vectorstore=vectorstore,
                query=query,
                k=search_document_number,
                metadata_filter=search_filters,
                sqlite_conn=sqlite_conn
            )
        else:
            results = reg_embedding_system.smart_search_vectorstore(
                bm25_retriever=bm25_retriever,
                vectorstore=vectorstore,
                query=query,
                k=search_document_number,
                sqlite_conn=sqlite_conn,
                enable_detailed_search=True
            )
        print(f"[{region}] 검색 완료: {len(results)}건")
        combined_results.extend(results)
    return combined_results
# --- Final AI answer ---
def RegAI(query, Rag_Results, ResultFile_FolderAddress):
    """Return the AI answer for `query` based on `Rag_Results`.

    Runs a garbage collection first. When there are no results, a fixed
    "no results" message is returned without calling the model.
    `ResultFile_FolderAddress` is accepted but not used.
    """
    gc.collect()
    if not Rag_Results:
        return "검색 결과가 없습니다."
    return Gemma3_AI_analysis(query, Rag_Results)
# Regulation type name -> v02 DB column. Accepts the legacy frontend names,
# their short forms, and the v02 column names themselves.
_FILTER_COLUMN_BY_TYPE = {
    # legacy frontend type names
    "regulation_part": "regulation",
    "regulation_section": "section",
    "chapter_section": "chapter",
    "jo": "standard",
    # short forms
    "part": "regulation",
    "section": "section",
    "chapter": "chapter",
    # v02 column names (this is what handle_search_query passes in)
    "regulation": "regulation",
    "standard": "standard",
}


def create_filter_by_type(regulation_type, regulation_title):
    """Build a filter dict that selects a single regulation.

    Args:
        regulation_type: a legacy type name, a short form, or a v02 column
            name; unknown types fall back to the "regulation" column.
        regulation_title: the value to filter on.

    Returns:
        A dict with the v02 columns regulation / section / chapter / standard,
        all empty lists except the matching column, which holds
        `regulation_title`.
    """
    filter_dict = {
        "regulation": [],
        "section": [],
        "chapter": [],
        "standard": []
    }
    filter_key = _FILTER_COLUMN_BY_TYPE.get(regulation_type, "regulation")
    filter_dict[filter_key].append(regulation_title)
    return filter_dict
# Regulation type name -> group key. Besides the short group names, the long
# legacy names and the v02 column names are accepted, matching the type names
# create_filter_by_type understands.
_REGULATION_GROUP_BY_TYPE = {
    "part": "part",
    "regulation_part": "part",
    "regulation": "part",
    "section": "section",
    "regulation_section": "section",
    "chapter": "chapter",
    "chapter_section": "chapter",
    "jo": "jo",
    "standard": "jo",
}


def group_regulations_by_type(selected_regulations):
    """Group the selected regulations' titles by regulation type.

    Args:
        selected_regulations: iterable of dicts with 'type' (default 'part')
            and 'title'. Entries that are not dicts, have an empty title, or
            have an unknown type are ignored.

    Returns:
        {"part": [...], "section": [...], "chapter": [...], "jo": [...]} with
        the titles in input order.
    """
    grouped = {
        "part": [],
        "section": [],
        "chapter": [],
        "jo": []
    }
    for regulation in selected_regulations or []:
        if not isinstance(regulation, dict):
            continue
        group = _REGULATION_GROUP_BY_TYPE.get(regulation.get('type', 'part'))
        regulation_title = regulation.get('title', '')
        if regulation_title and group is not None:
            grouped[group].append(regulation_title)
    return grouped
# Combined filter construction (v02 DB keys)
def create_combined_filters(grouped_regulations):
    """Map grouped regulations onto the v02 filter keys.

    part -> regulation, section -> section, chapter -> chapter, jo -> standard.
    The returned dict references the same list objects as `grouped_regulations`.
    """
    column_to_group = (
        ("regulation", "part"),     # regulation_part -> regulation
        ("section", "section"),     # regulation_section -> section
        ("chapter", "chapter"),     # chapter_section -> chapter
        ("standard", "jo"),         # jo -> standard
    )
    return {column: grouped_regulations[group] for column, group in column_to_group}
def get_unique_metadata_values(
    sqlite_conn: sqlite3.Connection,
    key_name: str,
    partial_match: Optional[str] = None
) -> str:
    """Return the distinct non-NULL values of a `documents` column as text.

    Args:
        sqlite_conn: open SQLite connection; a falsy value yields "".
        key_name: column of the `documents` table to read.
        partial_match: when given, only values containing this substring
            (SQL LIKE '%...%') are returned.

    Returns:
        The values in natural sort order, joined with newlines. An empty
        string when there is no connection or the query fails (the failure is
        logged).
    """
    if not sqlite_conn:
        return ""

    # NOTE(review): column names cannot be bound as SQL parameters, so
    # key_name is interpolated into the statement. Callers must only pass
    # trusted column names here.
    sql_query = f"SELECT DISTINCT `{key_name}` FROM documents"
    params = []
    if partial_match:
        sql_query += f" WHERE `{key_name}` LIKE ?"
        params.append(f"%{partial_match}%")

    cursor = sqlite_conn.cursor()
    try:
        cursor.execute(sql_query, params)
        unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None]
        unique_values.sort(key=reg_embedding_system.natural_sort_key)
        return "\n".join(str(value) for value in unique_values)
    except Exception as e:
        logger.error(f"[에러] 고유 값 검색 실패 ({key_name}): {e}")
        return ""
    finally:
        cursor.close()
# --- Entry point ---
import atexit

# Close the per-region SQLite connections at interpreter exit, both when this
# file is run as a script and when it is imported by a server process.
atexit.register(cleanup_connections)

# Load the RAG stores on a background daemon thread in both launch modes.
loading_thread = threading.Thread(target=load_rag_objects, daemon=True)
loading_thread.start()

if __name__ == '__main__':
    # Script mode: run the Socket.IO server directly.
    # The short pause presumably gives the loader a head start — TODO confirm.
    time.sleep(2)
    socketio.emit('message', {'message': '데이터 로딩 시작...'})
    socketio.run(app, host='0.0.0.0', port=7860, debug=False)