from typing import Annotated, Optional, Dict, List
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
class State(TypedDict, total=False):
    """Shared state dict passed between the nodes of the workflow graph.

    Every key is optional (``total=False``): the graph can be invoked with an
    empty dict, and each node adds the key(s) it produces. No defaults are
    declared because TypedDict does not support per-key default values;
    an unset key is simply absent from the dict.
    """

    # Pipeline inputs.
    sql_sp: Optional[str]  # stored-procedure source text
    java_code: Optional[str]  # existing Java source text
    # Intermediate LLM outputs.
    sp_analysis: Optional[str]
    java_analysis: Optional[str]
    user_story: Optional[str]
    # Final generated artefacts.
    generated_java_code: Optional[str]
    generated_test_code: Optional[str]
# Graph builder whose shared state schema is the State TypedDict above; the
# nodes and edges are registered on it further down before it is compiled.
workflow = StateGraph(State)
def setup_input_node(state: "State"):
    """Seed the pipeline inputs ``sql_sp`` and ``java_code`` into *state*.

    An input already present (and not ``None``) in the invocation state is
    kept, so the graph can be run as ``app.invoke({"sql_sp": ..., "java_code":
    ...})``. Any input that is missing falls back to the module-level variable
    of the same name, which is expected to be defined elsewhere in the
    module/notebook before the graph runs.

    Args:
        state: The current graph state dict; it is updated in place.

    Returns:
        The same *state* dict, now containing both input keys.

    Raises:
        NameError: if an input is absent from *state* and the corresponding
            module-level variable has not been defined.
    """
    # The global is only evaluated when the state lacks the value, so callers
    # that pass inputs explicitly do not need the globals at all.
    if state.get("sql_sp") is None:
        state["sql_sp"] = sql_sp
    if state.get("java_code") is None:
        state["java_code"] = java_code
    return state
def sp_analysis_node(state: State):
    """Run the stored-procedure analysis prompt and record its result.

    Pulls the ``peerislands/demo_step_1`` prompt via ``hub.pull`` and chains it
    with ``llm`` and ``output_parser``. It invokes the chain on the stored
    procedure held in ``state["sql_sp"]``, passes the output through
    ``mdformat.text``, and stores it under ``state["sp_analysis"]``.

    Args:
        state: Graph state; must already contain ``"sql_sp"``.

    Returns:
        The same *state* dict, updated in place.

    Raises:
        KeyError: if ``"sql_sp"`` is not present in *state*.
    """
    analyze_sp_prompt = hub.pull("peerislands/demo_step_1")
    node_1 = analyze_sp_prompt | llm | output_parser
    # Read the input from the graph state (seeded by the setup node) rather
    # than from a module-level global, so the node analyses whatever the
    # graph was actually invoked with.
    node_1_result = node_1.invoke({"stored_procedure": state["sql_sp"]})
    state["sp_analysis"] = mdformat.text(node_1_result)
    return state
def java_analysis_node(state: State):
    """Run the Java-code analysis prompt and record its result.

    Pulls the ``peerislands/demo_step_2`` prompt via ``hub.pull`` and chains it
    with ``llm`` and ``output_parser``. It invokes the chain on the Java source
    held in ``state["java_code"]``, passes the output through
    ``mdformat.text``, and stores it under ``state["java_analysis"]``.

    Args:
        state: Graph state; must already contain ``"java_code"``.

    Returns:
        The same *state* dict, updated in place.

    Raises:
        KeyError: if ``"java_code"`` is not present in *state*.
    """
    analyze_java_prompt = hub.pull("peerislands/demo_step_2")
    node_2 = analyze_java_prompt | llm | output_parser
    # Read the input from the graph state (seeded by the setup node) rather
    # than from a module-level global, so the node analyses whatever the
    # graph was actually invoked with.
    node_2_result = node_2.invoke({"java_code": state["java_code"]})
    state["java_analysis"] = mdformat.text(node_2_result)
    return state
def user_story_node(state: State):
    """Derive a user story from the two earlier analyses.

    Builds a chain from the ``peerislands/demo_step_3`` hub prompt, ``llm``
    and ``output_parser``. It feeds the chain ``state["sp_analysis"]`` and
    ``state["java_analysis"]``, and writes the output (passed through
    ``mdformat.text``) to ``state["user_story"]``.

    Returns the same state dict, mutated in place. A ``KeyError`` is raised
    if either analysis key is missing.
    """
    story_chain = hub.pull("peerislands/demo_step_3") | llm | output_parser
    prompt_vars = {
        "stored_procedure_analysis": state["sp_analysis"],
        "java_code_analysis": state["java_analysis"],
    }
    raw_story = story_chain.invoke(prompt_vars)
    state["user_story"] = mdformat.text(raw_story)
    return state
def generate_java_code_node(state: State):
    """Generate Java code from the user story stored in the state.

    Chains the ``peerislands/demo_step_4`` hub prompt with ``llm`` and
    ``output_parser`` and invokes it with ``state["user_story"]``. The output,
    passed through ``mdformat.text``, is stored as
    ``state["generated_java_code"]``.

    Returns the same state dict, mutated in place. A ``KeyError`` is raised
    if ``"user_story"`` is missing.
    """
    codegen_chain = hub.pull("peerislands/demo_step_4") | llm | output_parser
    raw_code = codegen_chain.invoke({"user_story": state["user_story"]})
    state["generated_java_code"] = mdformat.text(raw_code)
    return state
def generate_test_code_node(state: State):
    """Generate test code for the Java code produced by the previous step.

    Chains the ``peerislands/demo_step_5`` hub prompt with ``llm`` and
    ``output_parser`` and invokes it, passing ``state["generated_java_code"]``
    as the ``java_code`` prompt variable. The output, passed through
    ``mdformat.text``, is stored as ``state["generated_test_code"]``.

    Returns the same state dict, mutated in place. A ``KeyError`` is raised
    if ``"generated_java_code"`` is missing.
    """
    testgen_chain = hub.pull("peerislands/demo_step_5") | llm | output_parser
    source_under_test = state["generated_java_code"]
    raw_tests = testgen_chain.invoke({"java_code": source_under_test})
    state["generated_test_code"] = mdformat.text(raw_tests)
    return state
# Register each pipeline step as a node of the graph.
workflow.add_node("setup_input_node", setup_input_node)
workflow.add_node("sp_analysis_node", sp_analysis_node)
workflow.add_node("java_analysis_node", java_analysis_node)
workflow.add_node("user_story_node", user_story_node)
workflow.add_node("generate_java_code_node", generate_java_code_node)
workflow.add_node("generate_test_code_node", generate_test_code_node)

# Wire the steps into a strictly linear sequence:
# inputs -> SP analysis -> Java analysis -> user story -> Java code -> tests.
workflow.set_entry_point("setup_input_node")
workflow.add_edge("setup_input_node", "sp_analysis_node")
workflow.add_edge("sp_analysis_node", "java_analysis_node")
workflow.add_edge("java_analysis_node", "user_story_node")
workflow.add_edge("user_story_node", "generate_java_code_node")
workflow.add_edge("generate_java_code_node", "generate_test_code_node")
workflow.set_finish_point("generate_test_code_node")
app = workflow.compile()

# Best-effort rendering of the compiled graph. IPython may not be installed
# outside notebooks, and PNG rendering can fail depending on the environment.
# The import is therefore inside the guard, and failures are reported rather
# than aborting the run. `except Exception` (not a bare `except:`) keeps
# KeyboardInterrupt/SystemExit working.
try:
    from IPython.display import Image, display

    display(Image(app.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph visualization: {exc!r}")

# Pass the pipeline inputs explicitly so the invocation shows what the graph
# consumes. sql_sp and java_code are not defined in this section; they are
# expected to exist at module level before this line runs.
result = app.invoke({"sql_sp": sql_sp, "java_code": java_code})