Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,67 +1,28 @@
|
|
| 1 |
-
|
| 2 |
-
from
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
# MongoDB 연결
|
| 7 |
-
client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
|
| 8 |
-
db = client["two_tower_model"]
|
| 9 |
-
product_collection = db["product_tower"]
|
| 10 |
-
user_collection = db["user_tower"]
|
| 11 |
-
product_embedding_collection = db["product_embeddings"]
|
| 12 |
-
user_embedding_collection = db["user_embeddings"]
|
| 13 |
-
|
| 14 |
-
# 모델 학습
|
| 15 |
-
def train_model_and_embed():
|
| 16 |
-
product_model = None # Define or load your model
|
| 17 |
-
anchor_data, positive_data, negative_data = load_training_data()
|
| 18 |
-
trained_model = train_triplet_model(product_model, anchor_data, positive_data, negative_data)
|
| 19 |
-
|
| 20 |
-
return trained_model
|
| 21 |
-
|
| 22 |
-
# 데이터 임베딩 및 저장
|
| 23 |
-
def embed_and_save():
|
| 24 |
-
all_products = list(product_collection.find())
|
| 25 |
-
all_users = list(user_collection.find())
|
| 26 |
-
|
| 27 |
-
for product_data in all_products:
|
| 28 |
-
embedding = embed_product_data(product_data)
|
| 29 |
-
product_embedding_collection.update_one(
|
| 30 |
-
{"product_id": product_data["product_id"]},
|
| 31 |
-
{"$set": {"embedding": embedding.tolist()}},
|
| 32 |
-
upsert=True
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
for user_data in all_users:
|
| 36 |
-
embedding = embed_user_data(user_data)
|
| 37 |
-
user_embedding_collection.update_one(
|
| 38 |
-
{"user_id": user_data["user_id"]},
|
| 39 |
-
{"$set": {"embedding": embedding.tolist()}},
|
| 40 |
-
upsert=True
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
# 추천 실행
|
| 44 |
-
def recommend(user_id, top_n=5):
|
| 45 |
-
user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
|
| 46 |
-
if not user_embedding_data:
|
| 47 |
-
print(f"No embedding found for user_id: {user_id}")
|
| 48 |
-
return []
|
| 49 |
-
|
| 50 |
-
user_embedding = np.array(user_embedding_data["embedding"])
|
| 51 |
-
all_products = list(product_embedding_collection.find())
|
| 52 |
-
product_ids = [prod["product_id"] for prod in all_products]
|
| 53 |
-
product_embeddings = [prod["embedding"] for prod in all_products]
|
| 54 |
-
|
| 55 |
-
recommendations = calculate_cosine_similarity(user_embedding, product_embeddings, product_ids, top_n)
|
| 56 |
-
print(f"Recommendations for user {user_id}: {recommendations}")
|
| 57 |
-
return recommendations
|
| 58 |
-
|
| 59 |
-
# 실행
|
| 60 |
if __name__ == "__main__":
|
| 61 |
-
|
| 62 |
-
train_model_and_embed()
|
| 63 |
-
embed_and_save()
|
| 64 |
-
|
| 65 |
-
# Recommend products for a user
|
| 66 |
-
user_id = "정우석"
|
| 67 |
-
recommend(user_id, top_n=3)
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from calculate_cosine_similarity import (
|
| 3 |
+
find_most_similar_anchor,
|
| 4 |
+
find_most_similar_product,
|
| 5 |
+
recommend_shop_product,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
# 사용자 ID 입력
|
| 10 |
+
user_id = "user_123" # 사용자 ID 예시
|
| 11 |
+
|
| 12 |
+
# Step 1: 사용자와 가장 유사한 anchor 찾기
|
| 13 |
+
print(f"Finding the most similar anchor for user {user_id}...")
|
| 14 |
+
most_similar_anchor, anchor_embedding = find_most_similar_anchor(user_id)
|
| 15 |
+
print(f"Most similar anchor: {most_similar_anchor}")
|
| 16 |
+
|
| 17 |
+
# Step 2: anchor와 가장 유사한 상품 찾기
|
| 18 |
+
print("Finding the most similar product to the anchor...")
|
| 19 |
+
most_similar_product, similar_product_embedding = find_most_similar_product(anchor_embedding)
|
| 20 |
+
print(f"Most similar product to anchor: {most_similar_product}")
|
| 21 |
+
|
| 22 |
+
# Step 3: 쇼핑몰 상품 추천
|
| 23 |
+
print("Recommending the best shop product...")
|
| 24 |
+
recommended_product_id = recommend_shop_product(similar_product_embedding)
|
| 25 |
+
print(f"Recommended shop product ID: {recommended_product_id}")
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if __name__ == "__main__":
|
| 28 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|