|
from typing import Annotated, Optional, Dict, List |
|
from typing_extensions import TypedDict |
|
from langgraph.graph import StateGraph |
|
|
|
class State(TypedDict, total=False):
    """Shared state threaded through every node of the code-migration graph.

    ``total=False`` marks every key as optional. The graph is started with an
    empty dict (``app.invoke({})``) and each node fills in its own key.
    TypedDict does not apply class-level default values (type checkers reject
    them), so a key is simply absent until some node writes it.
    """

    # SQL stored procedure text; seeded by setup_input_node.
    sql_sp: Optional[str]
    # Existing Java source code; seeded by setup_input_node.
    java_code: Optional[str]
    # Analysis of the stored procedure; written by sp_analysis_node.
    sp_analysis: Optional[str]
    # Analysis of the Java code; written by java_analysis_node.
    java_analysis: Optional[str]
    # User story derived from both analyses; written by user_story_node.
    user_story: Optional[str]
    # Java code generated from the user story; written by generate_java_code_node.
    generated_java_code: Optional[str]
    # Tests generated for the generated Java code; written by generate_test_code_node.
    generated_test_code: Optional[str]
|
|
|
# Graph builder whose shared state schema is the State TypedDict defined above.
# Nodes and edges are registered on it further below, and then it is compiled.
workflow = StateGraph(State)
|
|
|
def setup_input_node(state: State):
    """Seed the workflow state with the stored procedure and the Java source.

    Values already present in ``state`` (for example, passed by the caller via
    ``app.invoke({...})``) are kept. Only missing or ``None`` entries fall back
    to the module-level ``sql_sp`` / ``java_code`` variables. Those globals are
    defined outside this view and must exist whenever a fallback is needed.

    Args:
        state: The current graph state; it is updated in place.

    Returns:
        The same ``state`` dict, with ``sql_sp`` and ``java_code`` populated.
    """
    # Previously the globals always overwrote the state, which silently
    # discarded caller-supplied inputs.
    if state.get("sql_sp") is None:
        state["sql_sp"] = sql_sp
    if state.get("java_code") is None:
        state["java_code"] = java_code
    return state
|
|
|
def sp_analysis_node(state: State):
    """Analyze the SQL stored procedure held in ``state["sql_sp"]``.

    The node does the following:
    - Pulls the ``peerislands/demo_step_1`` prompt via ``hub.pull``.
    - Pipes the prompt through ``llm`` and ``output_parser`` (all defined
      outside this view).
    - Passes the procedure as the ``stored_procedure`` prompt variable.
    - Stores the ``mdformat.text``-formatted result in ``state["sp_analysis"]``.

    Args:
        state: The current graph state; ``sql_sp`` must already be set.

    Returns:
        The same ``state`` dict, updated in place.
    """
    analyze_sp_prompt = hub.pull("peerislands/demo_step_1")
    node_1 = analyze_sp_prompt | llm | output_parser
    # Read the procedure from the graph state rather than the module-level
    # global. This makes the node analyze the value that was placed in state,
    # matching how the downstream nodes consume their inputs.
    node_1_result = node_1.invoke({"stored_procedure": state["sql_sp"]})
    state["sp_analysis"] = mdformat.text(node_1_result)
    return state
|
|
|
def java_analysis_node(state: State):
    """Analyze the Java source held in ``state["java_code"]``.

    The node does the following:
    - Pulls the ``peerislands/demo_step_2`` prompt via ``hub.pull``.
    - Pipes the prompt through ``llm`` and ``output_parser`` (all defined
      outside this view).
    - Passes the source as the ``java_code`` prompt variable.
    - Stores the ``mdformat.text``-formatted result in ``state["java_analysis"]``.

    Args:
        state: The current graph state; ``java_code`` must already be set.

    Returns:
        The same ``state`` dict, updated in place.
    """
    analyze_java_prompt = hub.pull("peerislands/demo_step_2")
    node_2 = analyze_java_prompt | llm | output_parser
    # Read the source from the graph state rather than the module-level
    # global, consistent with the other nodes.
    node_2_result = node_2.invoke({"java_code": state["java_code"]})
    state["java_analysis"] = mdformat.text(node_2_result)
    return state
|
|
|
def user_story_node(state: State):
    """Turn the stored-procedure and Java analyses into a user story.

    The node does the following:
    - Builds a chain from the ``peerislands/demo_step_3`` hub prompt,
      ``llm`` and ``output_parser``.
    - Feeds it ``state["sp_analysis"]`` and ``state["java_analysis"]``.
    - Saves the ``mdformat.text``-formatted output as ``state["user_story"]``.

    Args:
        state: The current graph state; both analyses must already be set.

    Returns:
        The same ``state`` dict, updated in place.
    """
    story_chain = hub.pull("peerislands/demo_step_3") | llm | output_parser
    prompt_vars = {
        "stored_procedure_analysis": state["sp_analysis"],
        "java_code_analysis": state["java_analysis"],
    }
    raw_story = story_chain.invoke(prompt_vars)
    state["user_story"] = mdformat.text(raw_story)
    return state
|
|
|
def generate_java_code_node(state: State):
    """Generate Java code from ``state["user_story"]``.

    The node does the following:
    - Runs the ``peerislands/demo_step_4`` hub prompt through ``llm`` and
      ``output_parser``, with the story as the ``user_story`` prompt variable.
    - Stores the ``mdformat.text``-formatted output in
      ``state["generated_java_code"]``.

    Args:
        state: The current graph state; ``user_story`` must already be set.

    Returns:
        The same ``state`` dict, updated in place.
    """
    codegen_chain = hub.pull("peerislands/demo_step_4") | llm | output_parser
    raw_code = codegen_chain.invoke({"user_story": state["user_story"]})
    state["generated_java_code"] = mdformat.text(raw_code)
    return state
|
|
|
def generate_test_code_node(state: State):
    """Generate test code for ``state["generated_java_code"]``.

    The node does the following:
    - Runs the ``peerislands/demo_step_5`` hub prompt through ``llm`` and
      ``output_parser``, with the generated code as the ``java_code`` prompt
      variable.
    - Stores the ``mdformat.text``-formatted output in
      ``state["generated_test_code"]``.

    Args:
        state: The current graph state; ``generated_java_code`` must already
            be set.

    Returns:
        The same ``state`` dict, updated in place.
    """
    testgen_chain = hub.pull("peerislands/demo_step_5") | llm | output_parser
    raw_tests = testgen_chain.invoke({"java_code": state["generated_java_code"]})
    state["generated_test_code"] = mdformat.text(raw_tests)
    return state
|
|
|
# Pipeline steps in execution order, as (node name, node function) pairs.
# Every node is registered first. Then the first step becomes the entry point,
# each step gets an edge to the next, and the last step becomes the finish point.
_PIPELINE_STEPS = (
    ("setup_input_node", setup_input_node),
    ("sp_analysis_node", sp_analysis_node),
    ("java_analysis_node", java_analysis_node),
    ("user_story_node", user_story_node),
    ("generate_java_code_node", generate_java_code_node),
    ("generate_test_code_node", generate_test_code_node),
)

for _step_name, _step_fn in _PIPELINE_STEPS:
    workflow.add_node(_step_name, _step_fn)

workflow.set_entry_point(_PIPELINE_STEPS[0][0])
for _src_step, _dst_step in zip(_PIPELINE_STEPS, _PIPELINE_STEPS[1:]):
    workflow.add_edge(_src_step[0], _dst_step[0])
workflow.set_finish_point(_PIPELINE_STEPS[-1][0])

# Freeze the graph definition into a runnable application.
app = workflow.compile()
|
|
|
# Best-effort rendering of the compiled graph as a PNG.
# - The IPython import is inside the try so the script still runs where
#   IPython is not installed.
# - Catching Exception (instead of a bare except) lets KeyboardInterrupt and
#   SystemExit propagate.
# - The failure reason is reported instead of silently discarded.
try:
    from IPython.display import Image, display

    display(Image(app.get_graph().draw_mermaid_png()))
except Exception as exc:  # visualization is optional; never block the run
    print(f"Skipping graph visualization: {exc!r}")

# Run the whole pipeline. setup_input_node seeds the inputs, so the graph
# starts from an empty state.
result = app.invoke({})