2
2
Module for creating the base graphs
3
3
"""
4
4
import time
5
+ import warnings
5
6
from langchain_community .callbacks import get_openai_callback
6
7
7
8
@@ -10,31 +11,37 @@ class BaseGraph:
10
11
BaseGraph manages the execution flow of a graph composed of interconnected nodes.
11
12
12
13
Attributes:
13
- nodes (dict ): A dictionary mapping each node's name to its corresponding node instance.
14
- edges (dict ): A dictionary representing the directed edges of the graph where each
14
+ nodes (list ): A dictionary mapping each node's name to its corresponding node instance.
15
+ edges (list ): A dictionary representing the directed edges of the graph where each
15
16
key-value pair corresponds to the from-node and to-node relationship.
16
17
entry_point (str): The name of the entry point node from which the graph execution begins.
17
18
18
19
Methods:
19
- execute(initial_state): Executes the graph's nodes starting from the entry point and
20
+ execute(initial_state): Executes the graph's nodes starting from the entry point and
20
21
traverses the graph based on the provided initial state.
21
22
22
23
Args:
23
24
nodes (iterable): An iterable of node instances that will be part of the graph.
24
- edges (iterable): An iterable of tuples where each tuple represents a directed edge
25
+ edges (iterable): An iterable of tuples where each tuple represents a directed edge
25
26
in the graph, defined by a pair of nodes (from_node, to_node).
26
27
entry_point (BaseNode): The node instance that represents the entry point of the graph.
27
28
"""
28
29
29
- def __init__ (self , nodes : dict , edges : dict , entry_point : str ):
30
+ def __init__ (self , nodes : list , edges : list , entry_point : str ):
30
31
"""
31
32
Initializes the graph with nodes, edges, and the entry point.
32
33
"""
33
- self .nodes = {node .node_name : node for node in nodes }
34
- self .edges = self ._create_edges (edges )
34
+
35
+ self .nodes = nodes
36
+ self .edges = self ._create_edges ({e for e in edges })
35
37
self .entry_point = entry_point .node_name
36
38
37
- def _create_edges (self , edges : dict ) -> dict :
39
+ if nodes [0 ].node_name != entry_point .node_name :
40
+ # raise a warning if the entry point is not the first node in the list
41
+ warnings .warn (
42
+ "Careful! The entry point node is different from the first node if the graph." )
43
+
44
+ def _create_edges (self , edges : list ) -> dict :
38
45
"""
39
46
Helper method to create a dictionary of edges from the given iterable of tuples.
40
47
@@ -51,8 +58,8 @@ def _create_edges(self, edges: dict) -> dict:
51
58
52
59
def execute (self , initial_state : dict ) -> dict :
53
60
"""
54
- Executes the graph by traversing nodes starting from the entry point. The execution
55
- follows the edges based on the result of each node's execution and continues until
61
+ Executes the graph by traversing nodes starting from the entry point. The execution
62
+ follows the edges based on the result of each node's execution and continues until
56
63
it reaches a node with no outgoing edges.
57
64
58
65
Args:
@@ -61,7 +68,8 @@ def execute(self, initial_state: dict) -> dict:
61
68
Returns:
62
69
dict: The state after execution has completed, which may have been altered by the nodes.
63
70
"""
64
- current_node_name = self .entry_point
71
+ print (self .nodes )
72
+ current_node_name = self .nodes [0 ]
65
73
state = initial_state
66
74
67
75
# variables for tracking execution info
@@ -75,10 +83,10 @@ def execute(self, initial_state: dict) -> dict:
75
83
"total_cost_USD" : 0.0 ,
76
84
}
77
85
78
- while current_node_name is not None :
86
+ for index in self . nodes :
79
87
80
88
curr_time = time .time ()
81
- current_node = self . nodes [ current_node_name ]
89
+ current_node = index
82
90
83
91
with get_openai_callback () as cb :
84
92
result = current_node .execute (state )
0 commit comments