Skip to content

Commit 89d2e66

Browse files
committed
Adding sampling strategies
1 parent 3a1891a commit 89d2e66

File tree

3 files changed

+80
-30
lines changed

3 files changed

+80
-30
lines changed

code_it/langchain/code_it_tool.py

+6-5
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,12 @@ def __init__(self, model_builder: HTTPBaseLLM, config: TaskExecutionConfig) -> N
1111
self.model_builder = model_builder
1212
self.config = config
1313

14-
logging.basicConfig(
15-
level=logging.INFO,
16-
format="%(asctime)s [%(levelname)s] %(message)s",
17-
handlers=[logging.StreamHandler()],
18-
)
14+
if config.log_to_stdout:
15+
logging.basicConfig(
16+
level=logging.INFO,
17+
format="%(asctime)s [%(levelname)s] %(message)s",
18+
handlers=[logging.StreamHandler()],
19+
)
1920

2021
def execute_task(self, task):
2122
code_editor = PythonCodeEditor()

code_it/task_executor.py

+73-24
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,7 @@
1-
"""This modules experiments building logic from scratch, without langchain."""
1+
"""This modules experiments building logic from scratch, without langchain.
2+
3+
It needs some refactoring :)
4+
"""
25
from dataclasses import dataclass
36
import logging
47
from code_it.code_editor.python_editor import PythonCodeEditor
@@ -14,24 +17,33 @@
1417
logger = logging.getLogger(__name__)
1518

1619
ANSWER_PATTERN = r"[a-zA-Z]+"
17-
DEPENDENCY_BLACKLIST = set(["random", "json"])
1820

21+
NO_SAMPLING = "NO_SAMPLING"
22+
PYLINT = "PYLINT"
23+
24+
DEPENDENCY_BLACKLIST = set(["random", "json"])
25+
SUPPORTED_SAMPLING_STRATEGIES = set([PYLINT, NO_SAMPLING])
1926

2027
def _trim_md(code_editor):
2128
if code_editor.source_code:
2229
code_editor.source_code[0] = code_editor.source_code[0].replace("```python", "")
2330
code_editor.source_code[-1] = code_editor.source_code[-1].replace("```", "")
2431
code_editor.overwrite_code(code_editor.display_code())
2532

33+
# TODO: add validation to the config
2634

2735
@dataclass
2836
class TaskExecutionConfig:
2937
execute_code = True
3038
install_dependencies = True
3139
apply_linter = True
3240
check_package_is_in_pypi = True
41+
log_to_stdout = True
42+
coding_samples = 10
43+
code_sampling_strategy = "PYLINT"
44+
sampling_temperature_multipler = 0.1
3345
dependency_samples = 3
34-
max_refactor_attempts = 5
46+
max_coding_attempts = 5
3547
dependency_install_attempts = 5
3648
planner_temperature = 0
3749
coder_temperature = 0
@@ -130,37 +142,74 @@ def execute(self, task: str):
130142
logger.info("Installed dependencies successfully!")
131143

132144
# Coding
133-
for i in range(self.config.max_refactor_attempts):
134-
logger.info("Coding, attempt: %s", i)
135-
refactored = self.coder.execute_task(
136-
source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan)
137-
)
138-
self.code_editor.overwrite_code(refactored)
139-
_trim_md(self.code_editor)
145+
if self.config.code_sampling_strategy == NO_SAMPLING:
146+
for i in range(self.config.max_coding_attempts):
147+
logger.info("Coding, attempt: %s", i)
148+
new_code = self.coder.execute_task(
149+
source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan)
150+
)
151+
self.code_editor.overwrite_code(new_code)
152+
_trim_md(self.code_editor)
140153

141-
logger.info(self.code_editor.display_code())
154+
logger.info(self.code_editor.display_code())
142155

143-
if self.config.apply_linter:
144-
logger.info("Applying linter...")
145-
(pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True)
146-
pylint_stdout = pylint_stdout.getvalue()
147-
logger.info(pylint_stdout)
156+
if self.config.apply_linter:
157+
logger.info("Applying linter...")
158+
(pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True)
159+
pylint_stdout = pylint_stdout.getvalue()
160+
logger.info(pylint_stdout)
161+
162+
new_code = self.linter.execute_task(
163+
source_code=self.code_editor.display_code(),
164+
stdout=pylint_stdout,
165+
)
166+
logger.warn("Linted code: %s", new_code)
167+
if new_code:
168+
self.code_editor.overwrite_code(new_code)
169+
170+
if not self.config.execute_code:
171+
return self.code_editor.display_code()
148172

149-
new_code = self.linter.execute_task(
150-
source_code=self.code_editor.display_code(),
151-
stdout=pylint_stdout,
173+
result = self.code_editor.run_code()
174+
175+
if "Succeeded" in result:
176+
break
177+
178+
elif self.config.code_sampling_strategy == PYLINT:
179+
coding_samples = []
180+
for i in range(self.config.code_sampling_strategy):
181+
self.coder.llm.set_parameter("temperature", i * self.config.sampling_temperature_multipler)
182+
logger.info("Coding sample: %s", i)
183+
new_code = self.coder.execute_task(
184+
source_code=self.code_editor.display_code(), objective=task, plan="\n".join(plan)
152185
)
153-
logger.warn("Linted code: %s", new_code)
154-
if new_code:
155-
self.code_editor.overwrite_code(new_code)
186+
coding_samples.append({"code": new_code})
187+
self.code_editor.overwrite_code(new_code)
188+
_trim_md(self.code_editor)
189+
190+
logger.info(self.code_editor.display_code())
191+
logger.info("Applying linter...")
156192

193+
(pylint_stdout, _) = lint.py_run(self.code_editor.filename, return_std=True)
194+
pylint_stdout = pylint_stdout.getvalue()
195+
split_1 = pylint_stdout.split("Your code has been rated at ")[0]
196+
linting_score_str = split_1.split("/")[0]
197+
score = float(linting_score_str)
198+
coding_samples[i]["score"] = score
199+
logger.info("Sample score: %s", score)
200+
201+
coding_samples.sort(key=lambda x: x["score"], reverse=True)
202+
highest_score = coding_samples[0]
203+
logger.info("Score of highest sample: %s", highest_score["score"])
204+
self.code_editor.overwrite_code(highest_score["code"])
157205
if not self.config.execute_code:
158206
return self.code_editor.display_code()
159207

160208
result = self.code_editor.run_code()
161209

162-
if "Succeeded" in result:
163-
break
210+
else:
211+
raise ValueError("Invalid Sampling Strategy")
212+
164213

165214
logger.info("Finished generating code!")
166215

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ def read(fname):
1414

1515
setup(
1616
name="code_it",
17-
version="0.3.1",
17+
version="0.4.0",
1818
author="Paolo Rechia",
1919
author_email="paolorechia@gmail.com",
2020
maintainer="Paolo Rechia",

0 commit comments

Comments
 (0)