Skip to content

Commit 4f1d28d

Browse files
authored
Merge pull request #81 from VinciGit00/refactoring_engine
Refactoring engine -> from set to list
2 parents 7c8dbb8 + cd8d3e7 commit 4f1d28d

File tree

7 files changed

+48
-36
lines changed

7 files changed

+48
-36
lines changed

examples/openai/custom_graph_openai.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@
6262
# ************************************************
6363

6464
graph = BaseGraph(
65-
nodes={
65+
nodes=[
6666
robot_node,
6767
fetch_node,
6868
parse_node,
6969
rag_node,
7070
generate_answer_node,
71-
},
72-
edges={
71+
],
72+
edges=[
7373
(robot_node, fetch_node),
7474
(fetch_node, parse_node),
7575
(parse_node, rag_node),
7676
(rag_node, generate_answer_node)
77-
},
77+
],
7878
entry_point=robot_node
7979
)
8080

manual deployment/commit_and_push_with_tests.sh

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ pylint pylint scrapegraphai/**/*.py scrapegraphai/*.py tests/**/*.py
1313

1414
cd tests
1515

16+
poetry install
17+
1618
# Run pytest
1719
if ! pytest; then
1820
echo "Pytest failed. Aborting commit and push."

scrapegraphai/graphs/base_graph.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Module for creating the base graphs
33
"""
44
import time
5+
import warnings
56
from langchain_community.callbacks import get_openai_callback
67

78

@@ -10,31 +11,37 @@ class BaseGraph:
1011
BaseGraph manages the execution flow of a graph composed of interconnected nodes.
1112
1213
Attributes:
13-
nodes (dict): A dictionary mapping each node's name to its corresponding node instance.
14-
edges (dict): A dictionary representing the directed edges of the graph where each
14+
nodes (list): A dictionary mapping each node's name to its corresponding node instance.
15+
edges (list): A dictionary representing the directed edges of the graph where each
1516
key-value pair corresponds to the from-node and to-node relationship.
1617
entry_point (str): The name of the entry point node from which the graph execution begins.
1718
1819
Methods:
19-
execute(initial_state): Executes the graph's nodes starting from the entry point and
20+
execute(initial_state): Executes the graph's nodes starting from the entry point and
2021
traverses the graph based on the provided initial state.
2122
2223
Args:
2324
nodes (iterable): An iterable of node instances that will be part of the graph.
24-
edges (iterable): An iterable of tuples where each tuple represents a directed edge
25+
edges (iterable): An iterable of tuples where each tuple represents a directed edge
2526
in the graph, defined by a pair of nodes (from_node, to_node).
2627
entry_point (BaseNode): The node instance that represents the entry point of the graph.
2728
"""
2829

29-
def __init__(self, nodes: dict, edges: dict, entry_point: str):
30+
def __init__(self, nodes: list, edges: list, entry_point: str):
3031
"""
3132
Initializes the graph with nodes, edges, and the entry point.
3233
"""
33-
self.nodes = {node.node_name: node for node in nodes}
34-
self.edges = self._create_edges(edges)
34+
35+
self.nodes = nodes
36+
self.edges = self._create_edges({e for e in edges})
3537
self.entry_point = entry_point.node_name
3638

37-
def _create_edges(self, edges: dict) -> dict:
39+
if nodes[0].node_name != entry_point.node_name:
40+
# raise a warning if the entry point is not the first node in the list
41+
warnings.warn(
42+
"Careful! The entry point node is different from the first node if the graph.")
43+
44+
def _create_edges(self, edges: list) -> dict:
3845
"""
3946
Helper method to create a dictionary of edges from the given iterable of tuples.
4047
@@ -51,8 +58,8 @@ def _create_edges(self, edges: dict) -> dict:
5158

5259
def execute(self, initial_state: dict) -> dict:
5360
"""
54-
Executes the graph by traversing nodes starting from the entry point. The execution
55-
follows the edges based on the result of each node's execution and continues until
61+
Executes the graph by traversing nodes starting from the entry point. The execution
62+
follows the edges based on the result of each node's execution and continues until
5663
it reaches a node with no outgoing edges.
5764
5865
Args:
@@ -61,7 +68,8 @@ def execute(self, initial_state: dict) -> dict:
6168
Returns:
6269
dict: The state after execution has completed, which may have been altered by the nodes.
6370
"""
64-
current_node_name = self.entry_point
71+
print(self.nodes)
72+
current_node_name = self.nodes[0]
6573
state = initial_state
6674

6775
# variables for tracking execution info
@@ -75,10 +83,10 @@ def execute(self, initial_state: dict) -> dict:
7583
"total_cost_USD": 0.0,
7684
}
7785

78-
while current_node_name is not None:
86+
for index in self.nodes:
7987

8088
curr_time = time.time()
81-
current_node = self.nodes[current_node_name]
89+
current_node = index
8290

8391
with get_openai_callback() as cb:
8492
result = current_node.execute(state)

scrapegraphai/graphs/script_creator_graph.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Module for creating the smart scraper
33
"""
44
from .base_graph import BaseGraph
@@ -57,17 +57,17 @@ def _create_graph(self):
5757
)
5858

5959
return BaseGraph(
60-
nodes={
60+
nodes=[
6161
fetch_node,
6262
parse_node,
6363
rag_node,
6464
generate_scraper_node,
65-
},
66-
edges={
65+
],
66+
edges=[
6767
(fetch_node, parse_node),
6868
(parse_node, rag_node),
6969
(rag_node, generate_scraper_node)
70-
},
70+
],
7171
entry_point=fetch_node
7272
)
7373

scrapegraphai/graphs/search_graph.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from .abstract_graph import AbstractGraph
1313

14+
1415
class SearchGraph(AbstractGraph):
1516
"""
1617
Module for searching info on the internet
@@ -49,19 +50,19 @@ def _create_graph(self):
4950
)
5051

5152
return BaseGraph(
52-
nodes={
53+
nodes=[
5354
search_internet_node,
5455
fetch_node,
5556
parse_node,
5657
rag_node,
5758
generate_answer_node,
58-
},
59-
edges={
59+
],
60+
edges=[
6061
(search_internet_node, fetch_node),
6162
(fetch_node, parse_node),
6263
(parse_node, rag_node),
6364
(rag_node, generate_answer_node)
64-
},
65+
],
6566
entry_point=search_internet_node
6667
)
6768

scrapegraphai/graphs/smart_scraper_graph.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Module for creating the smart scraper
33
"""
44
from .base_graph import BaseGraph
@@ -10,6 +10,7 @@
1010
)
1111
from .abstract_graph import AbstractGraph
1212

13+
1314
class SmartScraperGraph(AbstractGraph):
1415
"""
1516
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
@@ -52,25 +53,25 @@ def _create_graph(self):
5253
)
5354

5455
return BaseGraph(
55-
nodes={
56+
nodes=[
5657
fetch_node,
5758
parse_node,
5859
rag_node,
5960
generate_answer_node,
60-
},
61-
edges={
61+
],
62+
edges=[
6263
(fetch_node, parse_node),
6364
(parse_node, rag_node),
6465
(rag_node, generate_answer_node)
65-
},
66+
],
6667
entry_point=fetch_node
6768
)
6869

6970
def run(self) -> str:
7071
"""
7172
Executes the web scraping process and returns the answer to the prompt.
7273
"""
73-
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
74+
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
7475
self.final_state, self.execution_info = self.graph.execute(inputs)
7576

7677
return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/speech_graph.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ def _create_graph(self):
6262
)
6363

6464
return BaseGraph(
65-
nodes={
65+
nodes=[
6666
fetch_node,
6767
parse_node,
6868
rag_node,
6969
generate_answer_node,
7070
text_to_speech_node
71-
},
72-
edges={
71+
],
72+
edges=[
7373
(fetch_node, parse_node),
7474
(parse_node, rag_node),
7575
(rag_node, generate_answer_node),
7676
(generate_answer_node, text_to_speech_node)
77-
},
77+
],
7878
entry_point=fetch_node
7979
)
8080

0 commit comments

Comments
 (0)