Skip to content

Instantly share code, notes, and snippets.

@dranger003
Last active January 10, 2025 22:26
Show Gist options
  • Select an option

  • Save dranger003/5f5daa9c80b4193f180b93f71399f817 to your computer and use it in GitHub Desktop.

Select an option

Save dranger003/5f5daa9c80b4193f180b93f71399f817 to your computer and use it in GitHub Desktop.
A vLLM tool parser plugin for Command-R7B that handles function calling through action blocks.
# ** HOW TO USE **
# python -m vllm.entrypoints.openai.api_server \
# --pipeline-parallel-size "$GPU_COUNT" \
# --api-key "$API_KEY" \
# --model CohereForAI/c4ai-command-r7b-12-2024 \
# --chat-template c4ai-command-r7b-12-2024-tool_use.jinja \
# --chat-template-content-format string \
# --enable-auto-tool-choice \
# --tool-parser-plugin vllm-tool-parser-plugin-command-r7b.py \
# --tool-call-parser command-r7b
import json
import re
from typing import Dict, List, Optional, Sequence, Union, Any
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
is_complete_json,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ExtractedToolCallInformation,
DeltaMessage,
DeltaToolCall,
DeltaFunctionCall,
ToolCall,
FunctionCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["command-r7b"])
class CommandR7BToolParser(ToolParser):
    """
    A tool parser for Command-R7B that handles function calling through action blocks.

    This parser:
    - Removes response wrapper tokens (START_RESPONSE/END_RESPONSE)
    - Optionally removes thinking plans (START_THINKING/END_THINKING)
    - Looks for tool calls between START_ACTION and END_ACTION tokens
    - Handles both streaming and non-streaming parsing modes
    - Supports both single tool calls and arrays of tool calls

    Expected JSON format within action blocks:
    {
        "tool_call_id": "string",
        "tool_name": "string",
        "parameters": {
            // Tool-specific parameters
        }
    }
    """

    START_ACTION_TOKEN = "<|START_ACTION|>"
    END_ACTION_TOKEN = "<|END_ACTION|>"
    START_RESPONSE_TOKEN = "<|START_RESPONSE|>"
    END_RESPONSE_TOKEN = "<|END_RESPONSE|>"
    START_THINKING_TOKEN = "<|START_THINKING|>"
    END_THINKING_TOKEN = "<|END_THINKING|>"

    def __init__(self, tokenizer: AnyTokenizer, remove_thinking: bool = True):
        """
        Initialize the parser with necessary state tracking and token validation.

        Args:
            tokenizer: The tokenizer to use
            remove_thinking: Whether to remove thinking plans before tool calls

        Raises:
            RuntimeError: If any of the special wrapper tokens is absent
                from the tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        self.remove_thinking = remove_thinking
        # Initialize state tracking (also sets the partial-JSON buffer state).
        self.reset_state()
        # Validate tokens exist in vocabulary
        token_pairs = [
            (self.START_ACTION_TOKEN, self.END_ACTION_TOKEN),
            (self.START_RESPONSE_TOKEN, self.END_RESPONSE_TOKEN),
            (self.START_THINKING_TOKEN, self.END_THINKING_TOKEN),
        ]
        self.token_ids: Dict[str, int] = {}
        for start_token, end_token in token_pairs:
            start_id = self.vocab.get(start_token)
            end_id = self.vocab.get(end_token)
            if None in (start_id, end_id):
                raise RuntimeError(
                    f"Command-R7B parser could not locate {start_token} "
                    f"or {end_token} in tokenizer vocabulary"
                )
            self.token_ids[start_token] = start_id
            self.token_ids[end_token] = end_id

    def reset_state(self) -> None:
        """Reset parser state between requests."""
        # Whether we are currently inside an action / thinking block.
        self.in_action_block = False
        self.in_thinking_block = False
        # Token ids accumulated for the action block being streamed.
        self.current_block_token_ids: List[int] = []
        # Index assigned to the next emitted streaming tool call.
        self.tool_call_index = 0
        self.current_block_content = ""
        # Tool name -> tool schema (from the request), used for validation.
        self.available_tools: Dict[str, Dict] = {}
        # Buffer for partial JSON parsing across streamed action blocks.
        self.partial_json_buffer = ""
        self.last_complete_json = None
        self.current_tool_name_sent = False
        self.streamed_args_for_tool: List[str] = []

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Prepare for parsing by storing available tools and resetting state."""
        self.reset_state()
        if request.tools:
            self.available_tools = {
                tool.function.name: tool.function.model_dump() for tool in request.tools
            }
        return request

    def clean_response_text(self, text: str) -> str:
        """Remove response wrapper tokens from text.

        Only unwraps when BOTH wrapper tokens are present; otherwise the
        text is returned stripped but otherwise untouched.
        """
        if self.START_RESPONSE_TOKEN in text and self.END_RESPONSE_TOKEN in text:
            parts = text.split(self.START_RESPONSE_TOKEN, 1)
            if len(parts) > 1:
                text = parts[1]
            if self.END_RESPONSE_TOKEN in text:
                text = text.split(self.END_RESPONSE_TOKEN, 1)[0]
        return text.strip()

    def clean_thinking_text(self, text: str) -> str:
        """Remove thinking plan sections if configured to do so."""
        if not self.remove_thinking:
            return text
        while self.START_THINKING_TOKEN in text and self.END_THINKING_TOKEN in text:
            start_idx = text.find(self.START_THINKING_TOKEN)
            end_idx = text.find(self.END_THINKING_TOKEN, start_idx)
            if end_idx == -1:
                # END_THINKING occurs only *before* START_THINKING. The old
                # code computed end_idx from find() == -1, which produced a
                # negative-span splice that GREW the string and could loop
                # forever. Nothing sensible can be removed here, so stop.
                break
            end_idx += len(self.END_THINKING_TOKEN)
            text = text[:start_idx] + text[end_idx:]
        return text.strip()

    def validate_tool_call(self, tool_name: str, parameters: Dict) -> bool:
        """Validate a tool call against available tools.

        Returns True when the call looks usable; False (with a warning)
        when the tool is unknown, the parameters are not a dict, or the
        parameters cannot be JSON-serialized.
        """
        if not self.available_tools:
            return True  # No tools specified in request
        if tool_name not in self.available_tools:
            logger.warning("Tool '%s' not found in available tools", tool_name)
            return False
        # Check parameters format
        if not isinstance(parameters, dict):
            logger.warning("Parameters must be a dictionary for tool '%s'", tool_name)
            return False
        # Handle both "parameters" and "arguments" fields like InternLM2.
        # NOTE(review): this remaps a key *inside* the parameters dict, not
        # on the enclosing call dict — looks suspicious but is preserved
        # as-is; verify against real model output.
        if "arguments" in parameters and "parameters" not in parameters:
            parameters["parameters"] = parameters.pop("arguments")
        # Validate unicode content
        try:
            json.dumps(parameters, ensure_ascii=False)
        except UnicodeEncodeError:
            logger.warning("Invalid Unicode in parameters for tool '%s'", tool_name)
            return False
        return True

    def parse_action_block(
        self, block_text: str, streaming: bool = False
    ) -> List[Union[ToolCall, DeltaToolCall]]:
        """
        Parse a complete action block into tool calls.

        Args:
            block_text: The text to parse (JSON object or JSON array)
            streaming: Whether this is being called in streaming mode
                (currently unused; kept for interface compatibility)

        Returns:
            A list of ToolCall objects; empty on JSON parse failure or
            when no call passes validation.
        """
        try:
            parsed = json.loads(block_text)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse action block JSON: %s", e)
            return []
        if not isinstance(parsed, list):
            parsed = [parsed]
        tool_calls = []
        for call_dict in parsed:
            # Extract and validate required fields
            tool_name = call_dict.get("tool_name")
            parameters = call_dict.get("parameters", {})
            if not tool_name:
                logger.warning("Tool call missing required 'tool_name' field")
                continue
            if not self.validate_tool_call(tool_name, parameters):
                continue
            # Create the tool call with a fresh OpenAI-style id.
            tool_calls.append(
                ToolCall(
                    id=f"chatcmpl-tool-{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=tool_name, arguments=json.dumps(parameters)
                    ),
                )
            )
        return tool_calls

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from complete model output.

        Handles complete responses by:
        1. Removing response wrapper tokens
        2. Optionally removing thinking plans
        3. Finding all action blocks and parsing their content

        Content returned alongside the calls is whatever precedes the
        first action block (or the whole output when no block exists).
        """
        # Clean the output text
        model_output = self.clean_response_text(model_output)
        model_output = self.clean_thinking_text(model_output)
        # Regex for capturing everything between action markers
        action_pattern = f"{self.START_ACTION_TOKEN}(.*?){self.END_ACTION_TOKEN}"
        action_blocks = re.findall(action_pattern, model_output, re.DOTALL)
        all_tool_calls = []
        for block in action_blocks:
            block = block.strip()
            if not block:
                continue
            all_tool_calls.extend(self.parse_action_block(block))
        # Extract content before the first action block
        if self.START_ACTION_TOKEN in model_output:
            # Empty/whitespace-only prefix collapses to None.
            content = model_output.split(self.START_ACTION_TOKEN)[0].strip() or None
        else:
            content = model_output
        return ExtractedToolCallInformation(
            tools_called=bool(all_tool_calls),
            tool_calls=all_tool_calls,
            content=content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Optional[DeltaMessage]:
        """
        Handle streaming extraction of tool calls.

        Accumulates tokens within action blocks and parses them when complete.
        Returns appropriate streaming deltas for tool calls and content, or
        None when the chunk should be suppressed (wrapper tokens, thinking
        plans, or mid-action-block accumulation).
        """
        # Response wrapper tokens are never surfaced to the client.
        if self.token_ids[self.START_RESPONSE_TOKEN] in delta_token_ids:
            return None
        if self.token_ids[self.END_RESPONSE_TOKEN] in delta_token_ids:
            return None
        # Check for thinking tokens to potentially remove
        if self.remove_thinking:
            if self.token_ids[self.START_THINKING_TOKEN] in delta_token_ids:
                self.in_thinking_block = True
                return None
            if self.token_ids[self.END_THINKING_TOKEN] in delta_token_ids:
                self.in_thinking_block = False
                return None
            if self.in_thinking_block:
                return None
        new_tool_calls = []
        try:
            for tid in delta_token_ids:
                if tid == self.token_ids[self.START_ACTION_TOKEN]:
                    # Entering an action block
                    self.in_action_block = True
                    self.current_block_token_ids = []
                    self.current_block_content = ""
                elif tid == self.token_ids[self.END_ACTION_TOKEN]:
                    # Process complete action block
                    if self.current_block_token_ids:
                        block_text = self.model_tokenizer.decode(
                            self.current_block_token_ids
                        ).strip()
                        self.partial_json_buffer += block_text
                        if block_text:
                            try:
                                # Try partial JSON parsing first. Disallow
                                # partial strings until the tool name has
                                # been emitted, so names are never truncated.
                                flags = (
                                    Allow.ALL
                                    if self.current_tool_name_sent
                                    else Allow.ALL & ~Allow.STR
                                )
                                try:
                                    parsed_obj, end_idx = partial_json_loads(
                                        self.partial_json_buffer, flags
                                    )
                                    is_complete = is_complete_json(
                                        self.partial_json_buffer[:end_idx]
                                    )
                                    if isinstance(parsed_obj, (dict, list)):
                                        parsed = (
                                            [parsed_obj]
                                            if isinstance(parsed_obj, dict)
                                            else parsed_obj
                                        )
                                        # Remember the last fully-formed JSON.
                                        if is_complete:
                                            self.last_complete_json = parsed
                                    else:
                                        # Partial parse gave an unexpected
                                        # scalar; fall back to strict parsing.
                                        parsed = json.loads(block_text)
                                        if not isinstance(parsed, list):
                                            parsed = [parsed]
                                # NOTE(review): relies on
                                # partial_json_parser.core.exceptions being
                                # importable via attribute access — confirm
                                # against the installed package version.
                                except (
                                    partial_json_parser.core.exceptions.MalformedJSON
                                ):
                                    # If partial parsing fails, try regular JSON parse
                                    parsed = json.loads(block_text)
                                    if not isinstance(parsed, list):
                                        parsed = [parsed]
                                for call_dict in parsed:
                                    tool_name = call_dict.get("tool_name")
                                    parameters = call_dict.get("parameters", {})
                                    if not tool_name or not self.validate_tool_call(
                                        tool_name, parameters
                                    ):
                                        continue
                                    delta_fc = DeltaFunctionCall(
                                        name=tool_name, arguments=json.dumps(parameters)
                                    ).model_dump(exclude_none=True)
                                    new_tool_calls.append(
                                        DeltaToolCall(
                                            index=self.tool_call_index,
                                            type="function",
                                            id=f"chatcmpl-tool-{random_uuid()}",
                                            function=delta_fc,
                                        )
                                    )
                                    self.tool_call_index += 1
                            except json.JSONDecodeError as e:
                                logger.error("JSON parsing error in streaming: %s", e)
                    # Reset block state
                    self.in_action_block = False
                    self.current_block_token_ids = []
                    self.current_block_content = ""
                else:
                    # Accumulate tokens if in action block
                    if self.in_action_block:
                        self.current_block_token_ids.append(tid)
                    else:
                        # Regular content outside action block.
                        # NOTE(review): returning here drops any later tokens
                        # in this same delta chunk — preserved as-is.
                        return DeltaMessage(content=delta_text)
            # Return tool calls if any were found
            if new_tool_calls:
                return DeltaMessage(tool_calls=new_tool_calls)
            # Skip this chunk if we're still accumulating an action block
            if self.in_action_block:
                return None
            # Otherwise return regular content
            return DeltaMessage(content=delta_text)
        except Exception:
            logger.exception("Error in streaming tool call extraction")
            # Reset state and skip chunk on error
            self.reset_state()
            return None
@dranger003
Copy link
Author

C4AI Command R7B emits wrapper response tokens which are removed by this plugin. Also, when performing tool calls the model emits its plan ahead of the calls using wrapper thinking tokens. To keep the thinking plan in the model's response, pass remove_thinking=False (the default, remove_thinking: bool = True, strips it).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment