SandhyaMadhunagula commited on
Commit
be89e03
·
verified ·
1 Parent(s): 0c0e6f2

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +28 -0
  2. app.py +174 -0
  3. main.py +0 -0
  4. requirements.txt +95 -0
  5. view_db.py +12 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.9-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies (needed for some AI libraries)
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy the requirements file into the container
13
+ COPY requirements.txt .
14
+
15
+ # Install any needed packages specified in requirements.txt
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy the rest of your application code
19
+ COPY . .
20
+
21
+ # Create a directory for the graph if it doesn't exist
22
+ RUN mkdir -p static
23
+
24
+ # Flask apps on Hugging Face Spaces must run on port 7860
25
+ EXPOSE 7860
26
+
27
+ # Run app.py when the container launches
28
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, redirect, url_for, session
2
+ import networkx as nx
3
+ from pyvis.network import Network
4
+ import os, re, pickle
5
+ from dotenv import load_dotenv
6
+ from PyPDF2 import PdfReader
7
+ from docx import Document
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ import torch
10
+ import csv
11
+ from flask import Response
12
+ import io
13
+
14
+ app = Flask(__name__)
15
+ app.secret_key = "secret_key_for_session"
16
+
17
+ model_name = "Babelscape/rebel-large"
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+
21
+
22
+ load_dotenv() # This loads the variables from .env
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ rebel_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
25
+
26
+ #rebel_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
27
+ rebel_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN, low_cpu_mem_usage=True).to(device)
28
+
29
+
30
+ DB_FILE = "graph_database.pkl"
31
+
32
+ def save_db(graph):
33
+ with open(DB_FILE, "wb") as f:
34
+ pickle.dump(graph, f)
35
+
36
+ def load_db():
37
+ if os.path.exists(DB_FILE):
38
+ try:
39
+ with open(DB_FILE, "rb") as f:
40
+ return pickle.load(f)
41
+ except: return nx.DiGraph()
42
+ return nx.DiGraph()
43
+
44
+ G = load_db()
45
+
46
+ def extract_triples(text):
47
+ inputs = rebel_tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
48
+ gen_kwargs = {"max_length": 128, "length_penalty": 0, "num_beams": 1, "num_return_sequences": 1}
49
+ generated_tokens = rebel_model.generate(**inputs, **gen_kwargs)
50
+ decoded = rebel_tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)[0]
51
+
52
+ triples = []
53
+ current_subject, current_relation, current_object = "", "", ""
54
+ current_state = ""
55
+
56
+ # ADD THESE TWO LINES TO FIX THE "FIRST WORD" PROBLEM
57
+ clean_decoded = decoded.replace("<s>", "").replace("</s>", "")
58
+ clean_decoded = clean_decoded.replace("<triplet>", " <triplet> ").replace("<subj>", " <subj> ").replace("<obj>", " <obj> ")
59
+
60
+ # CHANGE THIS LOOP TO USE clean_decoded
61
+ for token in clean_decoded.split():
62
+ if token == "<triplet>":
63
+ current_state = "s"
64
+ if current_subject and current_relation and current_object:
65
+ triples.append((current_subject.strip(), current_relation.strip(), current_object.strip()))
66
+ current_subject, current_relation, current_object = "", "", ""
67
+ elif token == "<subj>": current_state = "o"
68
+ elif token == "<obj>": current_state = "r"
69
+ else:
70
+ if current_state == "s": current_subject += " " + token
71
+ elif current_state == "o": current_object += " " + token
72
+ elif current_state == "r": current_relation += " " + token
73
+
74
+ if current_subject and current_relation and current_object:
75
+ triples.append((current_subject.strip(), current_relation.strip(), current_object.strip()))
76
+ return triples
77
+
78
+ def visualize_graph():
79
+ net = Network(height="600px", width="100%", directed=True, bgcolor="#ffffff", font_color="black", cdn_resources='remote')
80
+ net.force_atlas_2based(gravity=-50, central_gravity=0.01, spring_length=150, damping=0.4)
81
+
82
+ # CRITICAL FIX: Loop through nodes and edges to draw them
83
+ for node in G.nodes():
84
+ net.add_node(node, label=node, color="#00d2ff", size=25, shadow={'enabled': True, 'color': 'rgba(0,210,255,0.6)', 'size': 10})
85
+ for source, target, data in G.edges(data=True):
86
+ net.add_edge(source, target, label=data.get("label", ""), color="#a29bfe")
87
+
88
+ if not os.path.exists("static"): os.makedirs("static")
89
+ net.save_graph("static/graph.html")
90
+
91
+ @app.route("/", methods=["GET", "POST"])
92
+ def index():
93
+ global G
94
+ answer = None
95
+ user_query = ""
96
+ text = session.get('user_text', "")
97
+
98
+ if request.method == "POST":
99
+ # 1. HANDLE FILE UPLOAD OR TEXT BOX
100
+ if "file" in request.files and request.files["file"].filename != "":
101
+ file = request.files["file"]
102
+ ext = file.filename.split('.')[-1].lower()
103
+ if ext == "pdf":
104
+ reader = PdfReader(file)
105
+ text = " ".join([page.extract_text() for page in reader.pages])
106
+ elif ext == "docx":
107
+ text = " ".join([p.text for p in Document(file).paragraphs])
108
+ elif ext == "txt":
109
+ text = file.read().decode("utf-8")
110
+ elif "text" in request.form and request.form["text"].strip():
111
+ text = request.form["text"]
112
+
113
+ # 2. PROCESS DATA (Only if we have new text)
114
+ if text and "query" not in request.form:
115
+ session['user_text'] = text
116
+ sentences = [s.strip() for s in re.split(r'[\n.!?]', text) if len(s.strip()) > 10]
117
+ print(f"--- 🚀 AI is extracting from {len(sentences)} sentences ---")
118
+ for i, sent in enumerate(sentences):
119
+ print(f"📄 Processing {i+1}/{len(sentences)}...")
120
+ for s, r, o in extract_triples(sent):
121
+ G.add_edge(s.title().strip(), o.title().strip(), label=r.strip())
122
+ save_db(G)
123
+ visualize_graph()
124
+
125
+ # 3. HANDLE SEARCH QUERY
126
+ if "query" in request.form:
127
+ user_query = request.form["query"].strip()
128
+ keywords = [w.lower() for w in user_query.split() if len(w) > 3]
129
+ results = []
130
+ for node in G.nodes():
131
+ if any(k in node.lower() for k in keywords):
132
+ for n in G.successors(node):
133
+ results.append(f"<b>{node}</b> {G[node][n]['label']} <b>{n}</b>")
134
+ for p in G.predecessors(node):
135
+ results.append(f"<b>{p}</b> {G[p][node]['label']} <b>{node}</b>")
136
+ answer = " • " + "<br> • ".join(list(set(results))[:8]) if results else f"Nothing found for '{user_query}'."
137
+
138
+ db_triples = [{"s": s, "r": d['label'], "o": t} for s, t, d in G.edges(data=True)]
139
+ return render_template("index.html", answer=answer, graph=os.path.exists("static/graph.html"), user_query=user_query, user_text=text, db_triples=db_triples)
140
+
141
+ @app.route("/export_csv")
142
+ def export_csv():
143
+ # 1. Create a string buffer to hold CSV data
144
+ output = io.StringIO()
145
+ writer = csv.writer(output)
146
+
147
+ # 2. Write the Header
148
+ writer.writerow(['Subject', 'Relationship', 'Object'])
149
+
150
+ # 3. Write the Data from the Graph G
151
+ for s, t, d in G.edges(data=True):
152
+ writer.writerow([s, d.get('label', ''), t])
153
+
154
+ # 4. Prepare the response for download
155
+ output.seek(0)
156
+ return Response(
157
+ output,
158
+ mimetype="text/csv",
159
+ headers={"Content-disposition": "attachment; filename=knowledge_graph.csv"}
160
+ )
161
+ @app.route("/clear")
162
+ def clear_db():
163
+ global G
164
+ G = nx.DiGraph()
165
+ session.clear()
166
+ if os.path.exists(DB_FILE): os.remove(DB_FILE)
167
+ if os.path.exists("static/graph.html"): os.remove("static/graph.html")
168
+ return redirect(url_for('index'))
169
+
170
+ #if __name__ == "__main__":
171
+ # app.run(debug=True)
172
+ if __name__ == "__main__":
173
+ # 0.0.0.0 makes it accessible to the internet
174
+ app.run(host="0.0.0.0", port=7860)
main.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.13.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.3
4
+ aiosignal==1.4.0
5
+ annotated-doc==0.0.4
6
+ anyio==4.12.1
7
+ asttokens==3.0.1
8
+ attrs==25.4.0
9
+ blinker==1.9.0
10
+ certifi==2026.1.4
11
+ charset-normalizer==3.4.4
12
+ click==8.3.1
13
+ colorama==0.4.6
14
+ contourpy==1.3.3
15
+ cycler==0.12.1
16
+ datasets==2.14.5
17
+ decorator==5.2.1
18
+ dill==0.3.7
19
+ executing==2.2.1
20
+ filelock==3.20.3
21
+ Flask==3.1.3
22
+ fonttools==4.61.1
23
+ frozenlist==1.8.0
24
+ fsspec==2023.6.0
25
+ h11==0.16.0
26
+ hf-xet==1.3.2
27
+ httpcore==1.0.9
28
+ httpx==0.28.1
29
+ huggingface_hub==1.6.0
30
+ idna==3.11
31
+ ipython==9.10.0
32
+ ipython_pygments_lexers==1.1.1
33
+ itsdangerous==2.2.0
34
+ jedi==0.19.2
35
+ Jinja2==3.1.6
36
+ joblib==1.5.3
37
+ jsonpickle==4.1.1
38
+ kiwisolver==1.4.9
39
+ lxml==6.0.2
40
+ markdown-it-py==4.0.0
41
+ MarkupSafe==3.0.3
42
+ matplotlib==3.10.8
43
+ matplotlib-inline==0.2.1
44
+ mdurl==0.1.2
45
+ mpmath==1.3.0
46
+ multidict==6.7.1
47
+ multiprocess==0.70.15
48
+ networkx==3.6.1
49
+ numpy==1.26.4
50
+ packaging==26.0
51
+ pandas==1.5.3
52
+ parso==0.8.6
53
+ pillow==12.0.0
54
+ prompt_toolkit==3.0.52
55
+ propcache==0.4.1
56
+ psutil==7.2.2
57
+ pure_eval==0.2.3
58
+ pyarrow==11.0.0
59
+ Pygments==2.19.2
60
+ pyparsing==3.3.2
61
+ PyPDF2==3.0.1
62
+ python-dateutil==2.9.0.post0
63
+ python-docx==1.2.0
64
+ pytz==2025.2
65
+ pyvis==0.3.2
66
+ PyYAML==6.0.3
67
+ regex==2026.1.15
68
+ requests==2.32.5
69
+ rich==14.3.3
70
+ safetensors==0.7.0
71
+ scikit-learn==1.8.0
72
+ scipy==1.17.0
73
+ sentencepiece==0.2.1
74
+ seqeval==1.2.2
75
+ shellingham==1.5.4
76
+ six==1.17.0
77
+ stack-data==0.6.3
78
+ sympy==1.14.0
79
+ threadpoolctl==3.6.0
80
+ tokenizers==0.22.2
81
+ torch==2.10.0+cpu
82
+ torchaudio==2.1.2+cpu
83
+ torchvision==0.16.2+cpu
84
+ tqdm==4.67.3
85
+ traitlets==5.14.3
86
+ transformers==5.3.0
87
+ typer==0.24.1
88
+ typer-slim==0.21.1
89
+ typing_extensions==4.15.0
90
+ tzdata==2025.3
91
+ urllib3==2.6.3
92
+ wcwidth==0.6.0
93
+ Werkzeug==3.1.6
94
+ xxhash==3.6.0
95
+ yarl==1.22.0
view_db.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import pandas as pd
3
+
4
+ with open("graph_database.pkl", "rb") as f:
5
+ G = pickle.load(f)
6
+
7
+ # Convert edges to a list of dictionaries
8
+ edge_list = [{"Subject": s, "Relation": d['label'], "Object": t} for s, t, d in G.edges(data=True)]
9
+
10
+ # Display as a table
11
+ df = pd.DataFrame(edge_list)
12
+ print(df)