| from .base_agent import BaseAgent | |
| from prompt.template import CREATE_CHART_PROMPT | |
class Chart(BaseAgent):
    """Agent that fills CREATE_CHART_PROMPT and sends it to the LLM.

    Charts are requested one at a time; each request is given the charts
    produced so far so the prompt can take them into account.
    """

    def __init__(self, llm):
        # Delegates entirely to BaseAgent; presumably that is where
        # ``self.llm`` gets set (BaseAgent is not visible here -- confirm).
        super().__init__(llm)

    def create_single_chart(self, paper_content: str, existing_charts: str, user_prompt: str = ''):
        """Request one chart from the LLM.

        Args:
            paper_content: Text substituted into the ``paper_content`` slot
                of the prompt template.
            existing_charts: Text substituted into the ``existing_charts``
                slot of the prompt template.
            user_prompt: Extra text substituted into the ``user_prompt``
                slot; empty by default.

        Returns:
            Whatever ``self.llm.generate`` returns for the filled prompt.
        """
        filled_prompt = CREATE_CHART_PROMPT.format(
            paper_content=paper_content,
            existing_charts=existing_charts,
            user_prompt=user_prompt,
        )
        return self.llm.generate(filled_prompt)

    def create_charts(self, paper_content: str, chart_num: int, user_prompt: str = ''):
        """Request ``chart_num`` charts in sequence.

        Args:
            paper_content: Forwarded unchanged to every single-chart request.
            chart_num: Number of requests to make.
            user_prompt: Forwarded unchanged to every single-chart request.

        Returns:
            A list with the result of each request, in request order.
        """
        separator = '\n---\n'
        generated = []
        history = ''  # the first request sees no earlier charts
        for _ in range(chart_num):
            generated.append(
                self.create_single_chart(paper_content, history, user_prompt)
            )
            # Rebuild the history from everything produced so far; the join
            # assumes each result is a string -- TODO confirm against the LLM
            # wrapper's generate().
            history = separator.join(generated)
        return generated