diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 88543030..65dc7b85 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -320,7 +320,11 @@ class ModelFacade: tool_schemas = None tool_call_turns = 0 total_tool_calls = 0 - curr_num_correction_steps = 0 + # Counts parse attempts within the current conversation (initial attempt + corrections). + # The first parse is attempt 1, so `parse_attempts <= max_correction_steps` permits exactly + # `max_correction_steps` corrections after the initial attempt before falling through to + # restart-or-raise. Reset to 0 on each conversation restart. + parse_attempts = 0 curr_num_restarts = 0 mcp_facade = self._get_mcp_facade(tool_alias) @@ -367,7 +371,7 @@ class ModelFacade: response = (completion_response.message.content or "").strip() reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) - curr_num_correction_steps += 1 + parse_attempts += 1 try: output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below @@ -379,12 +383,12 @@ class ModelFacade: exc, ) from exc - if curr_num_correction_steps <= max_correction_steps: + if parse_attempts <= max_correction_steps: # Add user message with error for correction messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) elif curr_num_restarts < max_conversation_restarts: - curr_num_correction_steps = 0 + parse_attempts = 0 curr_num_restarts += 1 messages = deepcopy(restart_checkpoint) tool_call_turns = checkpoint_tool_call_turns @@ -425,7 +429,8 @@ class ModelFacade: tool_schemas = None tool_call_turns = 0 total_tool_calls = 0 - curr_num_correction_steps = 0 + # See `generate` for a description of the parse-attempts counter semantics. + parse_attempts = 0 curr_num_restarts = 0 mcp_facade = self._get_mcp_facade(tool_alias) @@ -469,7 +474,7 @@ class ModelFacade: response = (completion_response.message.content or "").strip() reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) - curr_num_correction_steps += 1 + parse_attempts += 1 try: output_obj = parser(response) @@ -481,11 +486,11 @@ class ModelFacade: exc, ) from exc - if curr_num_correction_steps <= max_correction_steps: + if parse_attempts <= max_correction_steps: messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) elif curr_num_restarts < max_conversation_restarts: - curr_num_correction_steps = 0 + parse_attempts = 0 curr_num_restarts += 1 messages = deepcopy(restart_checkpoint) tool_call_turns = checkpoint_tool_call_turns