unsloth/007过滤think.py

648 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import time
import threading
from typing import Dict, Any, Iterator, Optional
from functools import wraps
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
class OllamaException(Exception):
    """Base class for all errors raised by the Ollama client."""
    pass
class OllamaServerError(OllamaException):
    """Raised when the Ollama server reports an error (e.g. HTTP 500/503)."""
    pass
class OllamaRateLimitError(OllamaException):
    """Raised when requests are throttled by the server (HTTP 429)."""
    pass
class OllamaClient:
    """Client for the Ollama HTTP API.

    Adds retry with exponential backoff, a simple client-side rate limit,
    a concurrency cap, request statistics, and optional filtering of
    "thinking"/chain-of-thought preambles from model output.
    """

    def __init__(self, base_url: str = "http://localhost:11434", max_retries: int = 3,
                 retry_delay: float = 1.0, max_concurrent_requests: int = 5,
                 timeout: int = 30, enable_logging: bool = True):
        """
        Initialize the Ollama client.

        Args:
            base_url: Base URL of the Ollama service (defaults to local port 11434).
            max_retries: Maximum number of retries per request.
            retry_delay: Base retry delay in seconds (doubles on each retry).
            max_concurrent_requests: Maximum number of in-flight requests.
            timeout: Per-request timeout in seconds.
            enable_logging: Whether to enable INFO-level logging.
        """
        self.base_url = base_url
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.max_concurrent_requests = max_concurrent_requests
        self.timeout = timeout
        # Logging setup
        if enable_logging:
            logging.basicConfig(level=logging.INFO)
            self.logger = logging.getLogger(__name__)
        else:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(logging.CRITICAL)
        # HTTP session.
        # BUGFIX NOTE: requests.Session has no `timeout` attribute, so the
        # assignment below never affected requests. The effective timeout is
        # now injected per call in _make_request(); the attribute is kept for
        # backward compatibility with code that reads `client.session.timeout`.
        self.session = requests.Session()
        self.session.timeout = timeout
        # Concurrency control
        self._request_lock = threading.Lock()
        self._active_requests = 0
        self._executor = ThreadPoolExecutor(max_workers=max_concurrent_requests)
        # Request statistics
        self._request_count = 0
        self._success_count = 0
        self._error_count = 0
        self._last_request_time = 0

    def _wait_for_slot(self):
        """Busy-wait until a request slot appears free.

        Must be called WITHOUT holding ``_request_lock``: releasing a slot
        requires that lock, so waiting while holding it would deadlock.
        """
        while self._active_requests >= self.max_concurrent_requests:
            time.sleep(0.1)

    def _acquire_request_slot(self):
        """Claim a request slot, blocking until one is available.

        BUGFIX: the original waited for a free slot while holding
        ``_request_lock``; ``_release_request_slot`` needs the same lock, so
        once ``max_concurrent_requests`` was reached the client deadlocked.
        The check-and-increment now happens atomically under the lock, and
        the sleep happens outside it.
        """
        while True:
            with self._request_lock:
                if self._active_requests < self.max_concurrent_requests:
                    self._active_requests += 1
                    self._request_count += 1
                    return
            time.sleep(0.1)

    def _release_request_slot(self):
        """Release a previously acquired request slot (never below zero)."""
        with self._request_lock:
            self._active_requests = max(0, self._active_requests - 1)

    def _handle_response_error(self, response, attempt: int):
        """Decide how to handle a non-200 response.

        Returns True when the caller should retry (after this method slept
        for the backoff interval); raises when retries are exhausted; for
        other status codes defers to ``raise_for_status`` and returns False.
        """
        if response.status_code == 500:
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)  # exponential backoff
                self.logger.warning(f"服务器错误500{attempt + 1}次重试,等待{wait_time:.1f}秒...")
                time.sleep(wait_time)
                return True  # caller should retry
            else:
                raise OllamaServerError(f"服务器错误500已达到最大重试次数{self.max_retries}")
        elif response.status_code == 429:
            # Rate limited by the server
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)
                self.logger.warning(f"请求过于频繁(429),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
                time.sleep(wait_time)
                return True
            else:
                raise OllamaRateLimitError(f"请求过于频繁,已达到最大重试次数{self.max_retries}")
        elif response.status_code == 503:
            # Service unavailable
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)
                self.logger.warning(f"服务不可用(503),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
                time.sleep(wait_time)
                return True
            else:
                raise OllamaServerError(f"服务不可用,已达到最大重试次数{self.max_retries}")
        else:
            # Any other HTTP error: raise if 4xx/5xx, otherwise give up quietly
            response.raise_for_status()
            return False

    def _make_request(self, method: str, endpoint: str, **kwargs) -> "requests.Response":
        """Send an HTTP request with retries, backoff and concurrency control.

        Args:
            method: "GET" (anything else is sent as POST).
            endpoint: Path appended to ``base_url`` (e.g. "/api/tags").
            **kwargs: Forwarded to requests (json=, stream=, ...).

        Returns:
            The successful (HTTP 200) response.

        Raises:
            OllamaServerError, OllamaRateLimitError or OllamaException when
            the request cannot be completed.
        """
        url = f"{self.base_url}{endpoint}"
        # BUGFIX: a Session-level `timeout` attribute is ignored by requests;
        # the timeout must accompany every call. Callers may still override it.
        kwargs.setdefault("timeout", self.timeout)
        for attempt in range(self.max_retries + 1):
            try:
                self._acquire_request_slot()
                # Simple client-side rate limit: at most one request per 100ms.
                current_time = time.time()
                if current_time - self._last_request_time < 0.1:
                    time.sleep(0.1)
                # BUGFIX: record the actual send time (after any sleep), not
                # the stale pre-sleep timestamp.
                self._last_request_time = time.time()
                self.logger.debug(f"发送请求: {method} {url}, 尝试 {attempt + 1}/{self.max_retries + 1}")
                if method.upper() == "GET":
                    response = self.session.get(url, **kwargs)
                else:
                    response = self.session.post(url, **kwargs)
                # Check response status
                if response.status_code == 200:
                    self._success_count += 1
                    return response
                else:
                    # Non-200: decide whether to retry (backoff sleep happens inside)
                    should_retry = self._handle_response_error(response, attempt)
                    if not should_retry:
                        break
            except OllamaException:
                # BUGFIX: exhausted-retry errors from _handle_response_error
                # previously escaped without being counted in the stats.
                self._error_count += 1
                raise
            except (requests.exceptions.ConnectTimeout,
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError) as e:
                if attempt < self.max_retries:
                    wait_time = self.retry_delay * (2 ** attempt)
                    self.logger.warning(f"网络错误: {e}, 第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
                    time.sleep(wait_time)
                else:
                    self._error_count += 1
                    raise OllamaException(f"网络连接失败: {e}")
            except requests.exceptions.RequestException as e:
                self._error_count += 1
                raise OllamaException(f"请求失败: {e}")
            finally:
                self._release_request_slot()
        # All retries failed
        self._error_count += 1
        raise OllamaException("所有重试都失败了")

    def get_stats(self) -> Dict[str, Any]:
        """Return request statistics (counts and success rate in percent)."""
        return {
            "total_requests": self._request_count,
            "successful_requests": self._success_count,
            "failed_requests": self._error_count,
            "active_requests": self._active_requests,
            "success_rate": self._success_count / max(self._request_count, 1) * 100
        }

    def list_models(self) -> Dict[str, Any]:
        """Return the installed model list, or an empty dict on failure."""
        try:
            response = self._make_request("GET", "/api/tags")
            return response.json()
        except Exception as e:
            self.logger.error(f"获取模型列表失败: {e}")
            return {}

    def generate(self, model: str, prompt: str, stream: bool = False,
                 no_thinking: bool = True, **kwargs) -> Any:
        """
        Generate text via /api/generate.

        Args:
            model: Model name.
            prompt: Input prompt.
            stream: Whether to stream the output.
            no_thinking: Suppress "thinking" preambles and answer directly.
            **kwargs: Extra parameters such as temperature, top_p, etc.

        Returns:
            The parsed JSON response, an iterator of chunks when streaming,
            or None on failure.
        """
        # When no_thinking is on, rewrite the prompt to demand a direct answer
        if no_thinking:
            prompt = self._format_no_thinking_prompt(prompt)
        data = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            **kwargs
        }
        # Add direct-answer system parameters
        if no_thinking:
            data.update({
                "system": "你是一个直接、简洁的AI助手。请直接回答问题不要显示思考过程或分析步骤。",
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.9),
                "repeat_penalty": kwargs.get("repeat_penalty", 1.1)
            })
        try:
            response = self._make_request(
                "POST",
                "/api/generate",
                json=data,
                stream=stream
            )
            if stream:
                return self._handle_stream_response(response, filter_thinking=no_thinking)
            else:
                result = response.json()
                if no_thinking and result:
                    result = self._filter_thinking_from_response(result)
                return result
        except Exception as e:
            self.logger.error(f"生成请求失败: {e}")
            return None

    def chat(self, model: str, messages: list, stream: bool = False,
             no_thinking: bool = True, **kwargs) -> Any:
        """
        Chat via /api/chat.

        Args:
            model: Model name.
            messages: Message list, e.g. [{"role": "user", "content": "..."}].
            stream: Whether to stream the output.
            no_thinking: Suppress "thinking" preambles and answer directly.
            **kwargs: Extra parameters.

        Returns:
            The parsed JSON response, an iterator of chunks when streaming,
            or None on failure.
        """
        # When no_thinking is on, prepend a system prompt to the messages
        if no_thinking:
            messages = self._format_no_thinking_messages(messages)
        data = {
            "model": model,
            "messages": messages,
            "stream": stream,
            **kwargs
        }
        # NOTE(review): /api/chat conveys the system role through `messages`;
        # whether a top-level "system" key is honored there is unverified.
        if no_thinking:
            data.update({
                "system": "你是一个直接、简洁的AI助手。请直接回答问题不要显示思考过程或分析步骤。",
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.9),
                "repeat_penalty": kwargs.get("repeat_penalty", 1.1)
            })
        try:
            response = self._make_request(
                "POST",
                "/api/chat",
                json=data,
                stream=stream
            )
            if stream:
                return self._handle_stream_response(response, filter_thinking=no_thinking)
            else:
                result = response.json()
                if no_thinking and result:
                    result = self._filter_thinking_from_chat_response(result)
                return result
        except Exception as e:
            self.logger.error(f"对话请求失败: {e}")
            return None

    def _format_no_thinking_prompt(self, prompt: str) -> str:
        """Wrap a prompt with instructions to skip the thinking process."""
        return f"""请直接回答以下问题,不要显示思考过程、分析步骤或推理过程:
{prompt}
注意:请直接给出答案,不要包含"让我想想""分析一下""首先""然后"等思考过程的表述。"""

    def _format_no_thinking_messages(self, messages: list) -> list:
        """Return a copy of *messages* with a direct-answer system prompt
        prepended and each user message wrapped accordingly."""
        formatted_messages = []
        # Leading system message
        system_message = {
            "role": "system",
            "content": "你是一个直接、简洁的AI助手。请直接回答问题不要显示思考过程、分析步骤或推理过程。不要使用'让我想想''分析一下''首先''然后'等表述。"
        }
        formatted_messages.append(system_message)
        # Original messages (user content gets wrapped, others pass through)
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content", "")
                formatted_content = f"请直接回答:{content}"
                formatted_messages.append({
                    "role": "user",
                    "content": formatted_content
                })
            else:
                formatted_messages.append(message)
        return formatted_messages

    def _filter_thinking_from_response(self, response: dict) -> dict:
        """Strip thinking phrases from a /api/generate response in place."""
        if "response" in response:
            content = response["response"]
            filtered_content = self._filter_thinking_text(content)
            response["response"] = filtered_content
        return response

    def _filter_thinking_from_chat_response(self, response: dict) -> dict:
        """Strip thinking phrases from a /api/chat response in place."""
        if "message" in response and "content" in response["message"]:
            content = response["message"]["content"]
            filtered_content = self._filter_thinking_text(content)
            response["message"]["content"] = filtered_content
        return response

    def _filter_thinking_text(self, text: str) -> str:
        """Remove common "thinking process" phrases/blocks from *text*."""
        import re  # local import kept; hoisted above the pattern list for clarity
        # Common thinking-process phrasings (Chinese) and explicit tags
        thinking_patterns = [
            r"让我想想[。,\n]*",
            r"让我来分析一下[。,\n]*",
            r"让我来思考一下[。,\n]*",
            r"首先[,。]*让我[^。]*[。,\n]*",
            r"我来分析一下[。,\n]*",
            r"我需要思考一下[。,\n]*",
            r"让我仔细考虑一下[。,\n]*",
            r"这个问题需要[^。]*分析[。,\n]*",
            r"思考:[^\n]*\n*",
            r"分析:[^\n]*\n*",
            r"<thinking>.*?</thinking>",
            r"\*思考\*[^\n]*\n*",
            r"\*分析\*[^\n]*\n*"
        ]
        filtered_text = text
        for pattern in thinking_patterns:
            filtered_text = re.sub(pattern, "", filtered_text, flags=re.DOTALL | re.IGNORECASE)
        # Collapse excess blank lines and trim
        filtered_text = re.sub(r'\n\s*\n', '\n\n', filtered_text)
        filtered_text = filtered_text.strip()
        return filtered_text

    def _handle_stream_response(self, response, filter_thinking: bool = False) -> Iterator[Dict[str, Any]]:
        """Yield parsed JSON chunks from a streaming response.

        When *filter_thinking* is set, chunks whose "response" text looks
        like a thinking preamble are dropped. Malformed lines are skipped.
        """
        for line in response.iter_lines():
            if line:
                try:
                    chunk = json.loads(line.decode('utf-8'))
                    if filter_thinking and "response" in chunk:
                        # Drop chunks that are pure thinking text
                        content = chunk["response"]
                        if not self._is_thinking_content(content):
                            yield chunk
                    else:
                        yield chunk
                except json.JSONDecodeError:
                    continue

    def _is_thinking_content(self, content: str) -> bool:
        """Heuristically detect whether *content* is thinking-process text."""
        thinking_keywords = [
            "让我想想", "让我来分析", "让我思考", "我来分析",
            "我需要思考", "让我仔细考虑", "思考:", "分析:",
            "<think>", "*思考*", "*分析*"
        ]
        # lower() only affects ASCII tags like <think>; Chinese is unchanged
        content_lower = content.lower()
        return any(keyword in content_lower for keyword in thinking_keywords)

    def batch_generate(self, model: str, prompts: list, no_thinking: bool = True, **kwargs) -> list:
        """
        Generate text for several prompts concurrently.

        Args:
            model: Model name.
            prompts: List of prompts.
            no_thinking: Suppress thinking preambles.
            **kwargs: Extra parameters forwarded to generate().

        Returns:
            A list of {"prompt", "response", "success"[, "error"]} dicts.
            NOTE: results arrive in completion order, not input order.
        """
        results = []

        def generate_single(prompt):
            # Run one generate() call; failures are logged and become None
            try:
                return self.generate(model, prompt, stream=False, no_thinking=no_thinking, **kwargs)
            except Exception as e:
                self.logger.error(f"批量生成失败 - 提示: {prompt[:50]}..., 错误: {e}")
                return None

        # Fan out over a dedicated thread pool (bounded like the client itself)
        with ThreadPoolExecutor(max_workers=self.max_concurrent_requests) as executor:
            future_to_prompt = {executor.submit(generate_single, prompt): prompt
                                for prompt in prompts}
            for future in as_completed(future_to_prompt):
                prompt = future_to_prompt[future]
                try:
                    result = future.result()
                    results.append({
                        "prompt": prompt,
                        "response": result,
                        "success": result is not None
                    })
                except Exception as e:
                    self.logger.error(f"批量请求异常: {e}")
                    results.append({
                        "prompt": prompt,
                        "response": None,
                        "success": False,
                        "error": str(e)
                    })
        return results

    def health_check(self) -> bool:
        """Return True when the service answers /api/tags with HTTP 200."""
        try:
            response = self._make_request("GET", "/api/tags")
            return response.status_code == 200
        except Exception as e:
            self.logger.error(f"健康检查失败: {e}")
            return False

    def pull_model(self, model_name: str) -> bool:
        """Pull *model_name*, printing progress. Returns True on completion."""
        data = {"name": model_name}
        try:
            response = self._make_request(
                "POST",
                "/api/pull",
                json=data,
                stream=True
            )
            print(f"正在拉取模型 {model_name}...")
            for chunk in self._handle_stream_response(response):
                if "status" in chunk:
                    print(f"状态: {chunk['status']}")
                if chunk.get("done", False):
                    print("模型拉取完成!")
                    return True
            # BUGFIX: previously fell off the end returning None when the
            # stream finished without a done flag; return False explicitly.
            return False
        except Exception as e:
            self.logger.error(f"拉取模型失败: {e}")
            return False

    def is_model_available(self, model_name: str) -> bool:
        """Return True when an installed model's name starts with *model_name*."""
        models = self.list_models()
        if "models" in models:
            return any(model["name"].startswith(model_name) for model in models["models"])
        return False

    def __enter__(self):
        """Context-manager entry: return self."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: drain the executor and close the session."""
        self._executor.shutdown(wait=True)
        self.session.close()
# Usage example
def main():
    """Demo: exercise the client's generate/chat/stream/batch features."""
    # Build a client with retry and concurrency settings.
    client = OllamaClient(
        max_retries=3,
        retry_delay=1.0,
        max_concurrent_requests=3,
        timeout=30,
        enable_logging=True
    )
    # Bail out early when the service is unreachable.
    if not client.health_check():
        print("Ollama服务不可用请检查服务状态")
        return
    # Confirm connectivity and show installed models.
    try:
        models = client.list_models()
        print("Ollama服务连接成功!")
        print(f"可用模型: {[m['name'] for m in models.get('models', [])]}")
    except Exception as e:
        print(f"无法连接到Ollama服务: {e}")
        return
    model_name = "Qwen3-8B:latest"  # adjust to whichever model you run
    # Pull the model when it is not installed yet.
    if not client.is_model_available(model_name):
        print(f"模型 {model_name} 不存在,正在拉取...")
        if not client.pull_model(model_name):
            print("模型拉取失败")
            return
    # Example 1: direct-answer mode (no thinking process).
    print("\n=== 直接回答模式 ===")
    try:
        response = client.generate(
            model=model_name,
            prompt="请用中文简单介绍一下人工智能。",
            no_thinking=True,  # direct-answer mode on
            temperature=0.7,
            max_tokens=200
        )
        print(f"回答: {response.get('response', '')}" if response else "请求失败")
    except Exception as e:
        print(f"请求异常: {e}")
    # Example 2: for comparison, answer WITH the thinking process.
    print("\n=== 对比:有思考过程的回答 ===")
    try:
        response = client.generate(
            model=model_name,
            prompt="请用中文简单介绍一下人工智能。",
            no_thinking=False,  # direct-answer mode off
            temperature=0.7,
            max_tokens=200
        )
        print(f"回答: {response.get('response', '')}" if response else "请求失败")
    except Exception as e:
        print(f"请求异常: {e}")
    # Example 3: streaming output in direct-answer mode.
    print("\n=== 流式输出 - 直接回答模式 ===")
    try:
        stream_response = client.generate(
            model=model_name,
            prompt="请讲一个简短的故事。",
            stream=True,
            no_thinking=True,  # direct-answer mode on
            temperature=0.8,
            max_tokens=200
        )
        print("流式回答: ", end="")
        for chunk in stream_response:
            if "response" in chunk:
                print(chunk["response"], end="", flush=True)
            if chunk.get("done", False):
                print("\n")
                break
    except Exception as e:
        print(f"流式输出异常: {e}")
    # Example 4: chat mode with direct answers.
    print("\n=== 对话模式 - 直接回答 ===")
    try:
        messages = [
            {"role": "user", "content": "请介绍一下机器学习的基本概念"},
        ]
        chat_response = client.chat(
            model=model_name,
            messages=messages,
            no_thinking=True,  # direct-answer mode on
            temperature=0.7,
            max_tokens=200
        )
        if chat_response:
            print(f"AI回答: {chat_response.get('message', {}).get('content', '')}")
    except Exception as e:
        print(f"对话异常: {e}")
    # Example 5: batch requests in direct-answer mode.
    print("\n=== 批量请求测试 - 直接回答模式 ===")
    prompts = [
        "什么是机器学习?",
        "什么是深度学习?",
        "什么是自然语言处理?",
        "什么是计算机视觉?",
        "什么是强化学习?"
    ]
    try:
        batch_results = client.batch_generate(
            model=model_name,
            prompts=prompts,
            no_thinking=True,  # direct-answer mode on
            temperature=0.7,
            max_tokens=100
        )
        for i, result in enumerate(batch_results):
            if result["success"]:
                print(f"请求 {i+1} 成功: {result['response'].get('response', '')[:100]}...")
            else:
                print(f"请求 {i+1} 失败: {result.get('error', '未知错误')}")
    except Exception as e:
        print(f"批量请求异常: {e}")
    # Finally dump the request statistics.
    print("\n=== 请求统计 ===")
    stats = client.get_stats()
    print("\n".join(f"{key}: {value}" for key, value in stats.items()))


if __name__ == "__main__":
    main()