|
1 |
| -"""This modules experiments building logic from scratch, without langchain.""" |
| 1 | +"""This modules experiments building logic from scratch, without langchain. |
| 2 | +
|
| 3 | +It needs some refactoring :) |
| 4 | +""" |
2 | 5 | from dataclasses import dataclass
|
3 | 6 | import logging
|
4 | 7 | from code_it.code_editor.python_editor import PythonCodeEditor
|
|
14 | 17 | logger = logging.getLogger(__name__)
|
15 | 18 |
|
16 | 19 | ANSWER_PATTERN = r"[a-zA-Z]+"
|
17 |
| -DEPENDENCY_BLACKLIST = set(["random", "json"]) |
18 | 20 |
|
| 21 | +NO_SAMPLING = "NO_SAMPLING" |
| 22 | +PYLINT = "PYLINT" |
| 23 | + |
| 24 | +DEPENDENCY_BLACKLIST = set(["random", "json"]) |
| 25 | +SUPPORTED_SAMPLING_STRATEGIES = set([PYLINT, NO_SAMPLING]) |
19 | 26 |
|
20 | 27 | def _trim_md(code_editor):
|
21 | 28 | if code_editor.source_code:
|
22 | 29 | code_editor.source_code[0] = code_editor.source_code[0].replace("```python", "")
|
23 | 30 | code_editor.source_code[-1] = code_editor.source_code[-1].replace("```", "")
|
24 | 31 | code_editor.overwrite_code(code_editor.display_code())
|
25 | 32 |
|
| 33 | +# TODO: add validation to the config |
26 | 34 |
|
27 | 35 | @dataclass
|
28 | 36 | class TaskExecutionConfig:
|
29 | 37 | execute_code = True
|
30 | 38 | install_dependencies = True
|
31 | 39 | apply_linter = True
|
32 | 40 | check_package_is_in_pypi = True
|
| 41 | + log_to_stdout = True |
| 42 | + coding_samples = 10 |
| 43 | + code_sampling_strategy = "PYLINT" |
| 44 | + sampling_temperature_multipler = 0.1 |
33 | 45 | dependency_samples = 3
|
34 |
| - max_refactor_attempts = 5 |
| 46 | + max_coding_attempts = 5 |
35 | 47 | dependency_install_attempts = 5
|
36 | 48 | planner_temperature = 0
|
37 | 49 | coder_temperature = 0
|
@@ -130,37 +142,74 @@ def execute(self, task: str):
|
130 | 142 | logger.info("Installed dependencies successfully!")
|
131 | 143 |
|
132 | 144 | # Coding
|
133 |
| - for i in range(self.config.max_refactor_attempts): |
134 |
| - logger.info("Coding, attempt: %s", i) |
135 |
| - refactored = self.coder.execute_task( |
136 |
| - source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan) |
137 |
| - ) |
138 |
| - self.code_editor.overwrite_code(refactored) |
139 |
| - _trim_md(self.code_editor) |
| 145 | + if self.config.code_sampling_strategy == NO_SAMPLING: |
| 146 | + for i in range(self.config.max_coding_attempts): |
| 147 | + logger.info("Coding, attempt: %s", i) |
| 148 | + new_code = self.coder.execute_task( |
| 149 | + source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan) |
| 150 | + ) |
| 151 | + self.code_editor.overwrite_code(new_code) |
| 152 | + _trim_md(self.code_editor) |
140 | 153 |
|
141 |
| - logger.info(self.code_editor.display_code()) |
| 154 | + logger.info(self.code_editor.display_code()) |
142 | 155 |
|
143 |
| - if self.config.apply_linter: |
144 |
| - logger.info("Applying linter...") |
145 |
| - (pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True) |
146 |
| - pylint_stdout = pylint_stdout.getvalue() |
147 |
| - logger.info(pylint_stdout) |
| 156 | + if self.config.apply_linter: |
| 157 | + logger.info("Applying linter...") |
| 158 | + (pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True) |
| 159 | + pylint_stdout = pylint_stdout.getvalue() |
| 160 | + logger.info(pylint_stdout) |
| 161 | + |
| 162 | + new_code = self.linter.execute_task( |
| 163 | + source_code=self.code_editor.display_code(), |
| 164 | + stdout=pylint_stdout, |
| 165 | + ) |
| 166 | + logger.warn("Linted code: %s", new_code) |
| 167 | + if new_code: |
| 168 | + self.code_editor.overwrite_code(new_code) |
| 169 | + |
| 170 | + if not self.config.execute_code: |
| 171 | + return self.code_editor.display_code() |
148 | 172 |
|
149 |
| - new_code = self.linter.execute_task( |
150 |
| - source_code=self.code_editor.display_code(), |
151 |
| - stdout=pylint_stdout, |
| 173 | + result = self.code_editor.run_code() |
| 174 | + |
| 175 | + if "Succeeded" in result: |
| 176 | + break |
| 177 | + |
| 178 | + elif self.config.code_sampling_strategy == PYLINT: |
| 179 | + coding_samples = [] |
| 180 | + for i in range(self.config.code_sampling_strategy): |
| 181 | + self.coder.llm.set_parameter("temperature", i * self.config.sampling_temperature_multipler) |
| 182 | + logger.info("Coding sample: %s", i) |
| 183 | + new_code = self.coder.execute_task( |
| 184 | + source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan) |
152 | 185 | )
|
153 |
| - logger.warn("Linted code: %s", new_code) |
154 |
| - if new_code: |
155 |
| - self.code_editor.overwrite_code(new_code) |
| 186 | + coding_samples.append({"code": new_code}) |
| 187 | + self.code_editor.overwrite_code(new_code) |
| 188 | + _trim_md(self.code_editor) |
| 189 | + |
| 190 | + logger.info(self.code_editor.display_code()) |
| 191 | + logger.info("Applying linter...") |
156 | 192 |
|
| 193 | + (pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True) |
| 194 | + pylint_stdout = pylint_stdout.getvalue() |
| 195 | + split_1 = pylint_stdout.split("Your code has been rated at ")[0] |
| 196 | + linting_score_str = split_1.split("/")[0] |
| 197 | + score = float(linting_score_str) |
| 198 | + coding_samples[i]["score"] = score |
| 199 | + logger.info("Sample score: %s", score) |
| 200 | + |
| 201 | + coding_samples.sort(key=lambda x: x["score"], reverse=True) |
| 202 | + highest_score = coding_samples[0] |
| 203 | + logger.info("Score of highest sample: %s", highest_score["score"]) |
| 204 | + self.code_editor.overwrite_code(highest_score["code"]) |
157 | 205 | if not self.config.execute_code:
|
158 | 206 | return self.code_editor.display_code()
|
159 | 207 |
|
160 | 208 | result = self.code_editor.run_code()
|
161 | 209 |
|
162 |
| - if "Succeeded" in result: |
163 |
| - break |
| 210 | + else: |
| 211 | + raise ValueError("Invalid Sampling Strategy") |
| 212 | + |
164 | 213 |
|
165 | 214 | logger.info("Finished generating code!")
|
166 | 215 |
|
|
0 commit comments