#@title FunctionGemma parsing code (extensible)
import re
def extract_tool_calls(text):
    """Parse FunctionGemma function-call blocks out of raw model output."""
    def cast(v):
        # Coerce string arguments to int/float/bool where possible; otherwise
        # strip surrounding quotes and return the string as-is.
        try: return int(v)
        except ValueError:
            try: return float(v)
            except ValueError: return {'true': True, 'false': False}.get(v.lower(), v.strip("'\""))
    return [{
        "name": name,
        "arguments": {
            k: cast((v1 or v2).strip())
            for k, v1, v2 in re.findall(r"(\w+):(?:<escape>(.*?)<escape>|([^,}]*))", args)
        }
    } for name, args in re.findall(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", text, re.DOTALL)]
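
# Quick sanity check (hypothetical model output, built to match the regexes above):
#   extract_tool_calls("<start_function_call>call:get_weather"
#                      "{location:<escape>New York<escape>,unit:celsius}<end_function_call>")
#   -> [{'name': 'get_weather', 'arguments': {'location': 'New York', 'unit': 'celsius'}}]
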
def process_tool_calls(output, messages):
    """Execute any tool calls found in `output` and append the results to `messages`."""
    calls = extract_tool_calls(output)
    if not calls: return messages
    messages.append({
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": call} for call in calls]
    })
    # Dispatch each call to its Python implementation (FUNCTION_MAPPING is
    # defined elsewhere in the notebook).
    results = [
        {"name": c['name'], "response": FUNCTION_MAPPING[c['name']](**c['arguments'])}
        for c in calls
    ]
    messages.append({ "role": "tool", "content": results })
    return messages
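
# `FUNCTION_MAPPING` (tool name -> Python callable) and `TOOLS` (schemas passed
# to the chat template) are assumed to be defined earlier in the notebook.
# A minimal hypothetical setup, kept as comments so it does not clobber the
# notebook's own definitions:
#   def get_weather(location, unit = 'celsius'):
#       return {'location': location, 'temperature': 21, 'unit': unit}
#   FUNCTION_MAPPING = {'get_weather': get_weather}
#   TOOLS = [get_weather]
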
def _do_inference(model, messages, max_new_tokens = 128):
    """Render the chat template, generate, and return only the newly generated text."""
    inputs = tokenizer.apply_chat_template(
        messages, tools = TOOLS, add_generation_prompt = True, return_dict = True, return_tensors = "pt",
    )
    # To inspect the rendered prompt, decode the input ids:
    # tokenizer.decode(inputs["input_ids"][0], skip_special_tokens = False)
    out = model.generate(**inputs.to(model.device), max_new_tokens = max_new_tokens,
                         top_p = 0.95, top_k = 64, temperature = 1.0,)
    # Keep only the tokens generated after the prompt.
    generated_tokens = out[0][len(inputs["input_ids"][0]):]
    return tokenizer.decode(generated_tokens, skip_special_tokens = True)
def do_inference(model, messages, print_assistant = True, max_new_tokens = 128):
    """One generation pass; if the model called a tool, run a second pass with the tool results."""
    output = _do_inference(model, messages, max_new_tokens = max_new_tokens)
    messages = process_tool_calls(output, messages)
    if messages[-1]["role"] == "tool":
        # The model requested a tool; feed the results back for a final answer.
        output = _do_inference(model, messages, max_new_tokens = max_new_tokens)
    messages.append({"role": "assistant", "content": output})
    if print_assistant: print(output)
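
# Example usage (hypothetical; assumes `model` and `tokenizer` are already loaded
# and that TOOLS/FUNCTION_MAPPING expose a weather tool as sketched above):
#   messages = [{"role": "user", "content": "What's the weather in New York?"}]
#   do_inference(model, messages)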