Polarisailabs committed
Commit 45b61b4 · verified
Parent: 8b9d3cc

Upload app.py

Files changed (1)
app.py (+45 -292)
app.py CHANGED
@@ -1,292 +1,45 @@
- import torch
- import torch.nn as nn
- import tiktoken
- import gradio as gr
-
- # ============== Model Classes ==============
- class PolarisAIMultiHeadAttention(nn.Module):
-     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
-         super().__init__()
-         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
-         self.d_out = d_out
-         self.num_heads = num_heads
-         self.head_dim = d_out // num_heads
-         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
-         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
-         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
-         self.W_output = nn.Linear(d_out, d_out, bias=qkv_bias)
-         self.dropout = nn.Dropout(dropout)
-         self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
-
-     def split_heads(self, x):
-         seq_len, d_out = x.shape
-         x = x.view(seq_len, self.num_heads, self.head_dim)
-         return x.transpose(0, 1)
-
-     def combine_heads(self, x):
-         num_heads, seq_len, head_dim = x.shape
-         x = x.transpose(0, 1)
-         return x.contiguous().view(seq_len, num_heads * head_dim)
-
-     def forward(self, x):
-         num_tokens, d_in = x.shape
-         allqueries = self.W_query(x)
-         allkeys = self.W_key(x)
-         allvalues = self.W_value(x)
-         queries_heads = self.split_heads(allqueries)
-         keys_heads = self.split_heads(allkeys)
-         values_heads = self.split_heads(allvalues)
-         attention_scores = queries_heads @ keys_heads.transpose(-2, -1)
-         masked = attention_scores.masked_fill(
-             self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
-         )
-         attention_weights = torch.softmax(masked / self.head_dim**0.5, dim=-1)
-         dropout_attention_weights = self.dropout(attention_weights)
-         context_heads = dropout_attention_weights @ values_heads
-         context_combined = self.combine_heads(context_heads)
-         return self.W_output(context_combined)
-
-
- class PolarisAILayerNorm(nn.Module):
-     def __init__(self, emb_dim):
-         super().__init__()
-         self.eps = 1e-5
-         self.scale = nn.Parameter(torch.ones(emb_dim))
-         self.shift = nn.Parameter(torch.zeros(emb_dim))
-
-     def forward(self, x):
-         mean = x.mean(dim=-1, keepdim=True)
-         var = x.var(dim=-1, keepdim=True, unbiased=False)
-         norm_x = (x - mean) / torch.sqrt(var + self.eps)
-         return self.scale * norm_x + self.shift
-
-
- class PolarisAIGELUActivation(nn.Module):
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x):
-         return 0.5 * x * (1 + torch.tanh(
-             torch.sqrt(torch.tensor(2.0 / torch.pi)) *
-             (x + 0.044715 * torch.pow(x, 3))
-         ))
-
-
- class PolarisAIFeedForwardNetwork(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         self.layers = nn.Sequential(
-             nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
-             PolarisAIGELUActivation(),
-             nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
-         )
-
-     def forward(self, x):
-         return self.layers(x)
-
-
- class PolarisAITransformerBlock(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         self.att = PolarisAIMultiHeadAttention(
-             d_in=cfg["emb_dim"], d_out=cfg["emb_dim"],
-             context_length=cfg["context_length"], num_heads=cfg["n_heads"],
-             dropout=cfg["drop_rate"], qkv_bias=cfg["qkv_bias"])
-         self.ff = PolarisAIFeedForwardNetwork(cfg)
-         self.norm1 = PolarisAILayerNorm(cfg["emb_dim"])
-         self.norm2 = PolarisAILayerNorm(cfg["emb_dim"])
-         self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
-
-     def forward(self, x):
-         shortcut = x
-         x = self.norm1(x)
-         x = self.att(x)
-         x = self.drop_shortcut(x)
-         x = x + shortcut
-         shortcut = x
-         x = self.norm2(x)
-         x = self.ff(x)
-         x = self.drop_shortcut(x)
-         return x + shortcut
-
-
- class PolarisAIPlatformModel(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
-         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
-         self.drop_emb = nn.Dropout(cfg["drop_rate"])
-         self.trf_blocks = nn.Sequential(
-             *[PolarisAITransformerBlock(cfg) for _ in range(cfg["n_layers"])])
-         self.final_norm = PolarisAILayerNorm(cfg["emb_dim"])
-         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
-         self.cfg = cfg
-
-     def forward(self, in_idx):
-         seq_len = in_idx.shape[0]
-         tok_embeds = self.tok_emb(in_idx)
-         pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
-         x = tok_embeds + pos_embeds
-         x = self.drop_emb(x)
-         x = self.trf_blocks(x)
-         x = self.final_norm(x)
-         return self.out_head(x)
-
-
- # ============== Generation Functions ==============
- def generate_text_simple(model, idx, max_new_tokens, context_size):
-     for _ in range(max_new_tokens):
-         idx_cond = idx[-context_size:]
-         with torch.no_grad():
-             logits = model(idx_cond)
-         logits = logits[-1, :]
-         probas = torch.softmax(logits, dim=-1)
-         idx_next = torch.argmax(probas).unsqueeze(0)
-         idx = torch.cat((idx, idx_next), dim=0)
-     return idx
-
-
- def generate_text_with_temperature(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
-     for _ in range(max_new_tokens):
-         idx_cond = idx[-context_size:]
-         with torch.no_grad():
-             logits = model(idx_cond)
-         logits = logits[-1, :]
-         if temperature > 0:
-             logits = logits / temperature
-             if top_k is not None and top_k > 0:
-                 top_k = min(top_k, logits.size(-1))
-                 values, indices = torch.topk(logits, top_k)
-                 logits = torch.full_like(logits, float('-inf'))
-                 logits.scatter_(-1, indices, values)
-             probas = torch.softmax(logits, dim=-1)
-             idx_next = torch.multinomial(probas, num_samples=1)
-         else:
-             idx_next = torch.argmax(logits).unsqueeze(0)
-         idx = torch.cat((idx, idx_next), dim=0)
-     return idx
-
-
- # ============== Initialize Tokenizer ==============
- tokenizer = tiktoken.get_encoding("gpt2")
-
-
- # ============== Gradio Function ==============
- def generate_text_gradio(
-     input_text,
-     max_new_tokens,
-     temperature,
-     top_k,
-     seed,
-     decoding_strategy,
-     vocab_size,
-     context_length,
-     emb_dim,
-     n_heads,
-     n_layers,
-     drop_rate,
-     qkv_bias
- ):
-     if not input_text.strip():
-         return "Please enter some text to generate from.", ""
-
-     # Validate emb_dim is divisible by n_heads
-     if emb_dim % n_heads != 0:
-         return f"Error: Embedding dimension ({emb_dim}) must be divisible by number of heads ({n_heads}).", ""
-
-     # Build config from UI inputs
-     config = {
-         "vocab_size": int(vocab_size),
-         "context_length": int(context_length),
-         "emb_dim": int(emb_dim),
-         "n_heads": int(n_heads),
-         "n_layers": int(n_layers),
-         "drop_rate": float(drop_rate),
-         "qkv_bias": bool(qkv_bias)
-     }
-
-     # Initialize model with user config
-     torch.manual_seed(int(seed))
-     model = PolarisAIPlatformModel(config)
-     model.eval()
-
-     # Calculate model info
-     total_params = sum(p.numel() for p in model.parameters())
-     model_size_mb = total_params * 4 / (1024 * 1024)
-     model_info = f"Parameters: {total_params:,} | Size: {model_size_mb:.2f} MB"
-
-     # Encode input
-     input_ids = torch.tensor(tokenizer.encode(input_text))
-
-     # Generate
-     if decoding_strategy == "Greedy":
-         output_ids = generate_text_simple(model, input_ids, int(max_new_tokens), config["context_length"])
-     else:
-         output_ids = generate_text_with_temperature(
-             model, input_ids, int(max_new_tokens),
-             config["context_length"], temperature,
-             int(top_k) if top_k > 0 else None
-         )
-
-     return tokenizer.decode(output_ids.tolist()), model_info
-
-
- # ============== Gradio Interface ==============
- with gr.Blocks(title="PolarisAI Platform", theme=gr.themes.Default(primary_hue='sky')) as PolarisAIPlatform:
-
-     with gr.Row():
-         # Left Column - Input/Output
-         with gr.Column(scale=2):
-             input_text = gr.Textbox(
-                 label="Input Text",
-                 placeholder="Enter text here...",
-                 lines=3,
-                 value=""
-             )
-             generate_btn = gr.Button("Generate Text", variant="primary", size="lg")
-             output_text = gr.Textbox(label="Generated Output", lines=8, interactive=False)
-             model_info_text = gr.Textbox(label="Model Info", interactive=False)
-
-         # Right Column - Parameters
-         with gr.Column(scale=1):
-             # Generation Parameters
-             decoding_strategy = gr.Radio(
-                 ["Greedy", "Temperature Sampling"],
-                 value="Greedy",
-                 label="Decoding Strategy"
-             )
-             max_new_tokens = gr.Slider(1, 100, value=10, step=1, label="Max New Tokens")
-             temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Temperature")
-             top_k = gr.Slider(0, 100, value=0, step=1, label="Top-K (0=disabled)")
-             seed = gr.Number(value=123, label="Random Seed", precision=0)
-
-             # Model Configuration Parameters
-             vocab_size = gr.Number(value=50257, label="Vocab Size", precision=0)
-             context_length = gr.Number(value=1024, label="Context Length", precision=0)
-             emb_dim = gr.Number(value=768, label="Embedding Dimension", precision=0)
-             n_heads = gr.Number(value=12, label="Number of Heads", precision=0)
-             n_layers = gr.Number(value=12, label="Number of Layers", precision=0)
-             drop_rate = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="Dropout Rate")
-             qkv_bias = gr.Checkbox(value=False, label="QKV Bias")
-
-     # Connect button
-     generate_btn.click(
-         generate_text_gradio,
-         inputs=[
-             input_text, max_new_tokens, temperature, top_k, seed, decoding_strategy,
-             vocab_size, context_length, emb_dim, n_heads, n_layers, drop_rate, qkv_bias
-         ],
-         outputs=[output_text, model_info_text]
-     )
-
-     # Submit on Enter
-     input_text.submit(
-         generate_text_gradio,
-         inputs=[
-             input_text, max_new_tokens, temperature, top_k, seed, decoding_strategy,
-             vocab_size, context_length, emb_dim, n_heads, n_layers, drop_rate, qkv_bias
-         ],
-         outputs=[output_text, model_info_text]
-     )
-
- PolarisAIPlatform.launch()
 
+ import os
+ import gradio as gr
+ from openai import OpenAI
+
+ API_KEY = os.environ['API_KEY']
+ client = OpenAI(base_url='https://openrouter.ai/api/v1', api_key=API_KEY)
+
+ def classify_text(text, classification_type='sentiment', custom_labels=''):
+     """Classify text using OpenRouter's GPT-OSS-20B model."""
+     if not text.strip():
+         return 'Please enter some text to classify.'
+     if classification_type == 'Sentiment':
+         prompt = f"Classify the sentiment of the following text as Positive, Negative, or Neutral. Only respond with one word: Positive, Negative, or Neutral.\n\nText: {text}"
+     elif classification_type == 'Spam':
+         prompt = f"Classify whether the following text is Spam or Not Spam. Only respond with: Spam or Not Spam.\n\nText: {text}"
+     else:
+         # Custom classification: constrain the answer to the user-supplied labels.
+         prompt = f"Classify the following text into one of these categories: {custom_labels}. Only respond with the category name.\n\nText: {text}"
+     try:
+         response = client.chat.completions.create(
+             model='openai/gpt-oss-20b',
+             messages=[
+                 {'role': 'system', 'content': 'You are a text classification assistant. Provide concise, accurate classifications.'},
+                 {'role': 'user', 'content': prompt}
+             ],
+             max_tokens=50,
+             temperature=0.1,
+             extra_headers={'HTTP-Referer': 'https://your-app-url.com', 'X-Title': ''}
+         )
+         result = response.choices[0].message.content.strip()
+         return f"Classification Result: {result}"
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ def batch_classify(file, classification_type='sentiment', custom_labels=''):
+     """Classify multiple texts from an uploaded file."""
+     if file is None:
+         return 'Please upload a text file.'
+     try:
+         with open(file.name, 'r', encoding='utf-8') as f:
+             lines = f.readlines()
+         results = []
+         for i, line in enumerate(lines[:10], 1):  # classify at most the first 10 lines
+             line = line.strip()
+             if line:
+                 result = classify_text(line, classification_type, custom_labels)
+                 results.append(f"{i}. **Text:** {line}\n   **Result:** {result}\n")
+         return '\n'.join(results) if results else 'No text found in file.'
+     except Exception as e:
+         return f"Error processing file: {str(e)}"
+
+ with gr.Blocks(title='', theme=gr.themes.Default(primary_hue='sky')) as demo:
+     with gr.Tabs():
+         with gr.Tab('Single Text'):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     text_input = gr.Textbox(label='', placeholder='Enter text to classify...', lines=4)
+                     classification_type = gr.Radio(choices=['Sentiment', 'Spam'], value='Sentiment', label='Classification Type:')
+                     custom_labels = gr.Textbox(label='Custom Labels (for custom classification)', placeholder='e.g., business, technology, sports, entertainment', visible=False)
+                     classify_btn = gr.Button('Classify Text', variant='primary')
+                 with gr.Column(scale=2):
+                     single_output = gr.Markdown(value='')
+
+             def toggle_custom_labels(choice):
+                 return gr.update(visible=choice == 'custom')  # box appears only for a 'custom' type
+
+             classification_type.change(toggle_custom_labels, inputs=[classification_type], outputs=[custom_labels])
+             classify_btn.click(classify_text, inputs=[text_input, classification_type, custom_labels], outputs=[single_output])
+
+         with gr.Tab('Batch Classification'):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     gr.Markdown('Upload a text or CSV file:')
+                     file_input = gr.File(label='Upload File', file_types=['.txt', '.csv'])
+                     batch_classification_type = gr.Radio(choices=['Sentiment', 'Spam'], value='Sentiment', label='Classification Type:')
+                     batch_custom_labels = gr.Textbox(label='Custom Labels (for custom classification)', placeholder='e.g., business, technology, sports, entertainment', visible=False)
+                     batch_classify_btn = gr.Button('🔍 Classify Batch', variant='primary')
+                 with gr.Column(scale=2):
+                     batch_output = gr.Markdown(value='')
+
+             def toggle_batch_custom_labels(choice):
+                 return gr.update(visible=choice == 'custom')
+
+             batch_classification_type.change(toggle_batch_custom_labels, inputs=[batch_classification_type], outputs=[batch_custom_labels])
+             batch_classify_btn.click(batch_classify, inputs=[file_input, batch_classification_type, batch_custom_labels], outputs=[batch_output])
+
+ if __name__ == '__main__':
+     demo.launch(server_name='0.0.0.0', server_port=7860, share=True, show_error=True)
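For a quick sanity check of the new classifier outside the Gradio UI, a minimal sketch along these lines should work, assuming the file above is saved as app.py, the API_KEY environment variable holds a valid OpenRouter key, and openrouter.ai is reachable; the sample sentences are illustrative and 'sk-or-...' is a placeholder, not a real key:

import os

# app.py reads API_KEY at import time, so set it before importing.
os.environ.setdefault('API_KEY', 'sk-or-...')  # placeholder key

from app import classify_text

# Each call returns a string like "Classification Result: Positive".
print(classify_text('The battery life on this laptop is fantastic.', 'Sentiment'))
print(classify_text('WIN A FREE IPHONE!!! Click here now!!!', 'Spam'))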