Skip to content

Commit d5f1c82

Browse files
committed
Merge remote-tracking branch 'origin/refactor_genrerate_answer_node'
2 parents 4a63e2a + 8fc4187 commit d5f1c82

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

scrapegraphai/nodes/generate_answer_node.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,27 @@ def execute(self, state):
119119
chain_name = f"chunk{i+1}"
120120
chains_dict[chain_name] = prompt | self.llm_model | output_parser
121121

122-
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
123-
map_chain = RunnableParallel(**chains_dict)
124-
# Chain
125-
answer_map = map_chain.invoke({"question": user_prompt})
126-
127-
# Merge the answers from the chunks
128-
merge_prompt = PromptTemplate(
129-
template=template_merge,
130-
input_variables=["context", "question"],
131-
partial_variables={"format_instructions": format_instructions},
132-
)
133-
merge_chain = merge_prompt | self.llm_model | output_parser
134-
answer = merge_chain.invoke(
135-
{"context": answer_map, "question": user_prompt})
136-
137-
# Update the state with the generated answer
138-
state.update({self.output[0]: answer})
139-
return state
122+
if len(chains_dict) > 1:
123+
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
124+
map_chain = RunnableParallel(**chains_dict)
125+
# Chain
126+
answer_map = map_chain.invoke({"question": user_prompt})
127+
128+
# Merge the answers from the chunks
129+
merge_prompt = PromptTemplate(
130+
template=template_merge,
131+
input_variables=["context", "question"],
132+
partial_variables={"format_instructions": format_instructions},
133+
)
134+
merge_chain = merge_prompt | self.llm_model | output_parser
135+
answer = merge_chain.invoke(
136+
{"context": answer_map, "question": user_prompt})
137+
138+
# Update the state with the generated answer
139+
state.update({self.output[0]: answer})
140+
return state
141+
142+
else:
143+
# Update the state with the generated answer
144+
state.update({self.output[0]: chains_dict})
145+
return state

0 commit comments

Comments
 (0)