from .base_agent import BaseAgent
from prompt.template import CREATE_CHART_PROMPT
class Chart(BaseAgent):
    """Agent that prompts the LLM for chart output based on paper content."""

    def __init__(self, llm):
        """Pass ``llm`` through to the ``BaseAgent`` initializer."""
        super().__init__(llm)

    def create_single_chart(self, paper_content: str, existing_charts: str, user_prompt: str=''):
        """Generate one chart from the paper content.

        Args:
            paper_content: Text substituted into the ``paper_content`` slot
                of ``CREATE_CHART_PROMPT``.
            existing_charts: Text substituted into the ``existing_charts``
                slot of the prompt template.
            user_prompt: Optional extra text substituted into the
                ``user_prompt`` slot of the template.

        Returns:
            Whatever ``self.llm.generate`` returns for the filled-in prompt.
        """
        # NOTE(review): ``self.llm`` is presumably assigned by
        # ``BaseAgent.__init__`` -- confirm against the base class.
        return self.llm.generate(
            CREATE_CHART_PROMPT.format(
                paper_content=paper_content,
                existing_charts=existing_charts,
                user_prompt=user_prompt,
            )
        )

    def create_charts(self, paper_content: str, chart_num: int, user_prompt: str=''):
        """Generate ``chart_num`` charts one after another.

        Every request after the first is given all previously generated
        charts, joined with a ``---`` separator line, as its
        ``existing_charts`` argument.

        Args:
            paper_content: Text forwarded to every chart request.
            chart_num: Number of charts to generate.
            user_prompt: Optional extra text forwarded to every request.

        Returns:
            A list with the generated charts in creation order.
        """
        generated = []
        previous = ''
        for _ in range(chart_num):
            generated.append(
                self.create_single_chart(paper_content, previous, user_prompt)
            )
            # Rebuild the running context so the next request sees
            # everything produced so far.
            previous = '\n---\n'.join(generated)
        return generated