# Ollama API client: retries with exponential backoff, client-side rate
# limiting, concurrency control, and optional filtering of "thinking" output.
import json
import logging
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from typing import Any, Dict, Iterator, Optional

import requests
|
||
|
||
class OllamaException(Exception):
    """Base exception for all errors raised by the Ollama client."""
    pass
|
||
|
||
class OllamaServerError(OllamaException):
    """Raised for server-side failures (HTTP 500 / 503) after retries are exhausted."""
    pass
|
||
|
||
class OllamaRateLimitError(OllamaException):
    """Raised when the server keeps rate-limiting (HTTP 429) after retries are exhausted."""
    pass
|
||
|
||
class OllamaClient:
|
||
def __init__(self, base_url: str = "http://localhost:11434", max_retries: int = 3,
|
||
retry_delay: float = 1.0, max_concurrent_requests: int = 5,
|
||
timeout: int = 30, enable_logging: bool = True):
|
||
"""
|
||
初始化Ollama客户端
|
||
|
||
Args:
|
||
base_url: Ollama服务的基础URL,默认为本地11434端口
|
||
max_retries: 最大重试次数
|
||
retry_delay: 重试间隔(秒)
|
||
max_concurrent_requests: 最大并发请求数
|
||
timeout: 请求超时时间(秒)
|
||
enable_logging: 是否启用日志
|
||
"""
|
||
self.base_url = base_url
|
||
self.max_retries = max_retries
|
||
self.retry_delay = retry_delay
|
||
self.max_concurrent_requests = max_concurrent_requests
|
||
self.timeout = timeout
|
||
|
||
# 设置日志
|
||
if enable_logging:
|
||
logging.basicConfig(level=logging.INFO)
|
||
self.logger = logging.getLogger(__name__)
|
||
else:
|
||
self.logger = logging.getLogger(__name__)
|
||
self.logger.setLevel(logging.CRITICAL)
|
||
|
||
# 创建会话
|
||
self.session = requests.Session()
|
||
self.session.timeout = timeout
|
||
|
||
# 并发控制
|
||
self._request_lock = threading.Lock()
|
||
self._active_requests = 0
|
||
self._executor = ThreadPoolExecutor(max_workers=max_concurrent_requests)
|
||
|
||
# 请求统计
|
||
self._request_count = 0
|
||
self._success_count = 0
|
||
self._error_count = 0
|
||
self._last_request_time = 0
|
||
|
||
def _wait_for_slot(self):
|
||
"""等待可用的请求槽位"""
|
||
while self._active_requests >= self.max_concurrent_requests:
|
||
time.sleep(0.1)
|
||
|
||
def _acquire_request_slot(self):
|
||
"""获取请求槽位"""
|
||
with self._request_lock:
|
||
self._wait_for_slot()
|
||
self._active_requests += 1
|
||
self._request_count += 1
|
||
|
||
def _release_request_slot(self):
|
||
"""释放请求槽位"""
|
||
with self._request_lock:
|
||
self._active_requests = max(0, self._active_requests - 1)
|
||
|
||
def _handle_response_error(self, response, attempt: int):
|
||
"""处理响应错误"""
|
||
if response.status_code == 500:
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||
self.logger.warning(f"服务器错误500,第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True # 需要重试
|
||
else:
|
||
raise OllamaServerError(f"服务器错误500,已达到最大重试次数{self.max_retries}")
|
||
|
||
elif response.status_code == 429:
|
||
# 速率限制
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"请求过于频繁(429),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True
|
||
else:
|
||
raise OllamaRateLimitError(f"请求过于频繁,已达到最大重试次数{self.max_retries}")
|
||
|
||
elif response.status_code == 503:
|
||
# 服务不可用
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"服务不可用(503),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True
|
||
else:
|
||
raise OllamaServerError(f"服务不可用,已达到最大重试次数{self.max_retries}")
|
||
|
||
else:
|
||
# 其他HTTP错误
|
||
response.raise_for_status()
|
||
|
||
return False
|
||
|
||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||
"""发送请求的通用方法,包含重试逻辑"""
|
||
url = f"{self.base_url}{endpoint}"
|
||
|
||
for attempt in range(self.max_retries + 1):
|
||
try:
|
||
self._acquire_request_slot()
|
||
|
||
# 速率限制检查
|
||
current_time = time.time()
|
||
if current_time - self._last_request_time < 0.1: # 最小间隔100ms
|
||
time.sleep(0.1)
|
||
|
||
self._last_request_time = current_time
|
||
|
||
self.logger.debug(f"发送请求: {method} {url}, 尝试 {attempt + 1}/{self.max_retries + 1}")
|
||
|
||
if method.upper() == "GET":
|
||
response = self.session.get(url, **kwargs)
|
||
else:
|
||
response = self.session.post(url, **kwargs)
|
||
|
||
# 检查响应状态
|
||
if response.status_code == 200:
|
||
self._success_count += 1
|
||
return response
|
||
else:
|
||
# 处理错误响应
|
||
should_retry = self._handle_response_error(response, attempt)
|
||
if not should_retry:
|
||
break
|
||
|
||
except (requests.exceptions.ConnectTimeout,
|
||
requests.exceptions.ReadTimeout,
|
||
requests.exceptions.ConnectionError) as e:
|
||
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"网络错误: {e}, 第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
else:
|
||
self._error_count += 1
|
||
raise OllamaException(f"网络连接失败: {e}")
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
self._error_count += 1
|
||
raise OllamaException(f"请求失败: {e}")
|
||
|
||
finally:
|
||
self._release_request_slot()
|
||
|
||
# 如果所有重试都失败了
|
||
self._error_count += 1
|
||
raise OllamaException("所有重试都失败了")
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取请求统计信息"""
|
||
return {
|
||
"total_requests": self._request_count,
|
||
"successful_requests": self._success_count,
|
||
"failed_requests": self._error_count,
|
||
"active_requests": self._active_requests,
|
||
"success_rate": self._success_count / max(self._request_count, 1) * 100
|
||
}
|
||
|
||
def list_models(self) -> Dict[str, Any]:
|
||
"""获取已安装的模型列表"""
|
||
try:
|
||
response = self._make_request("GET", "/api/tags")
|
||
return response.json()
|
||
except Exception as e:
|
||
self.logger.error(f"获取模型列表失败: {e}")
|
||
return {}
|
||
|
||
def generate(self, model: str, prompt: str, stream: bool = False,
|
||
no_thinking: bool = True, **kwargs) -> Any:
|
||
"""
|
||
生成文本
|
||
|
||
Args:
|
||
model: 模型名称
|
||
prompt: 输入提示
|
||
stream: 是否流式输出
|
||
no_thinking: 是否禁用思考过程,直接回答
|
||
**kwargs: 其他参数如temperature, top_p等
|
||
"""
|
||
# 如果启用no_thinking,修改prompt以指示模型直接回答
|
||
if no_thinking:
|
||
prompt = self._format_no_thinking_prompt(prompt)
|
||
|
||
data = {
|
||
"model": model,
|
||
"prompt": prompt,
|
||
"stream": stream,
|
||
**kwargs
|
||
}
|
||
|
||
# 添加直接回答的系统参数
|
||
if no_thinking:
|
||
data.update({
|
||
"system": "你是一个直接、简洁的AI助手。请直接回答问题,不要显示思考过程或分析步骤。",
|
||
"temperature": kwargs.get("temperature", 0.7),
|
||
"top_p": kwargs.get("top_p", 0.9),
|
||
"repeat_penalty": kwargs.get("repeat_penalty", 1.1)
|
||
})
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/generate",
|
||
json=data,
|
||
stream=stream
|
||
)
|
||
|
||
if stream:
|
||
return self._handle_stream_response(response, filter_thinking=no_thinking)
|
||
else:
|
||
result = response.json()
|
||
if no_thinking and result:
|
||
result = self._filter_thinking_from_response(result)
|
||
return result
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"生成请求失败: {e}")
|
||
return None
|
||
|
||
def chat(self, model: str, messages: list, stream: bool = False,
|
||
no_thinking: bool = True, **kwargs) -> Any:
|
||
"""
|
||
对话模式
|
||
|
||
Args:
|
||
model: 模型名称
|
||
messages: 消息列表,格式为[{"role": "user", "content": "..."}]
|
||
stream: 是否流式输出
|
||
no_thinking: 是否禁用思考过程,直接回答
|
||
**kwargs: 其他参数
|
||
"""
|
||
# 如果启用no_thinking,在消息中添加系统提示
|
||
if no_thinking:
|
||
messages = self._format_no_thinking_messages(messages)
|
||
|
||
data = {
|
||
"model": model,
|
||
"messages": messages,
|
||
"stream": stream,
|
||
**kwargs
|
||
}
|
||
|
||
# 添加直接回答的参数
|
||
if no_thinking:
|
||
data.update({
|
||
"system": "你是一个直接、简洁的AI助手。请直接回答问题,不要显示思考过程或分析步骤。",
|
||
"temperature": kwargs.get("temperature", 0.7),
|
||
"top_p": kwargs.get("top_p", 0.9),
|
||
"repeat_penalty": kwargs.get("repeat_penalty", 1.1)
|
||
})
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/chat",
|
||
json=data,
|
||
stream=stream
|
||
)
|
||
|
||
if stream:
|
||
return self._handle_stream_response(response, filter_thinking=no_thinking)
|
||
else:
|
||
result = response.json()
|
||
if no_thinking and result:
|
||
result = self._filter_thinking_from_chat_response(result)
|
||
return result
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"对话请求失败: {e}")
|
||
return None
|
||
|
||
def _format_no_thinking_prompt(self, prompt: str) -> str:
|
||
"""格式化prompt以避免显示思考过程"""
|
||
return f"""请直接回答以下问题,不要显示思考过程、分析步骤或推理过程:
|
||
|
||
{prompt}
|
||
|
||
注意:请直接给出答案,不要包含"让我想想"、"分析一下"、"首先"、"然后"等思考过程的表述。"""
|
||
|
||
def _format_no_thinking_messages(self, messages: list) -> list:
|
||
"""格式化消息以避免显示思考过程"""
|
||
formatted_messages = []
|
||
|
||
# 添加系统消息
|
||
system_message = {
|
||
"role": "system",
|
||
"content": "你是一个直接、简洁的AI助手。请直接回答问题,不要显示思考过程、分析步骤或推理过程。不要使用'让我想想'、'分析一下'、'首先'、'然后'等表述。"
|
||
}
|
||
formatted_messages.append(system_message)
|
||
|
||
# 添加原始消息
|
||
for message in messages:
|
||
if message.get("role") == "user":
|
||
content = message.get("content", "")
|
||
formatted_content = f"请直接回答:{content}"
|
||
formatted_messages.append({
|
||
"role": "user",
|
||
"content": formatted_content
|
||
})
|
||
else:
|
||
formatted_messages.append(message)
|
||
|
||
return formatted_messages
|
||
|
||
def _filter_thinking_from_response(self, response: dict) -> dict:
|
||
"""从响应中过滤掉思考过程"""
|
||
if "response" in response:
|
||
content = response["response"]
|
||
# 过滤常见的思考过程表述
|
||
filtered_content = self._filter_thinking_text(content)
|
||
response["response"] = filtered_content
|
||
return response
|
||
|
||
def _filter_thinking_from_chat_response(self, response: dict) -> dict:
|
||
"""从对话响应中过滤掉思考过程"""
|
||
if "message" in response and "content" in response["message"]:
|
||
content = response["message"]["content"]
|
||
# 过滤常见的思考过程表述
|
||
filtered_content = self._filter_thinking_text(content)
|
||
response["message"]["content"] = filtered_content
|
||
return response
|
||
|
||
def _filter_thinking_text(self, text: str) -> str:
|
||
"""过滤文本中的思考过程表述"""
|
||
# 常见的思考过程表述
|
||
thinking_patterns = [
|
||
r"让我想想[。,\n]*",
|
||
r"让我来分析一下[。,\n]*",
|
||
r"让我来思考一下[。,\n]*",
|
||
r"首先[,。]*让我[^。]*[。,\n]*",
|
||
r"我来分析一下[。,\n]*",
|
||
r"我需要思考一下[。,\n]*",
|
||
r"让我仔细考虑一下[。,\n]*",
|
||
r"这个问题需要[^。]*分析[。,\n]*",
|
||
r"思考:[^\n]*\n*",
|
||
r"分析:[^\n]*\n*",
|
||
r"<thinking>.*?</thinking>",
|
||
r"\*思考\*[^\n]*\n*",
|
||
r"\*分析\*[^\n]*\n*"
|
||
]
|
||
|
||
import re
|
||
filtered_text = text
|
||
for pattern in thinking_patterns:
|
||
filtered_text = re.sub(pattern, "", filtered_text, flags=re.DOTALL | re.IGNORECASE)
|
||
|
||
# 清理多余的空行和空格
|
||
filtered_text = re.sub(r'\n\s*\n', '\n\n', filtered_text)
|
||
filtered_text = filtered_text.strip()
|
||
|
||
return filtered_text
|
||
|
||
def _handle_stream_response(self, response, filter_thinking: bool = False) -> Iterator[Dict[str, Any]]:
|
||
"""处理流式响应"""
|
||
for line in response.iter_lines():
|
||
if line:
|
||
try:
|
||
chunk = json.loads(line.decode('utf-8'))
|
||
if filter_thinking and "response" in chunk:
|
||
# 实时过滤思考过程
|
||
content = chunk["response"]
|
||
if not self._is_thinking_content(content):
|
||
yield chunk
|
||
else:
|
||
yield chunk
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
def _is_thinking_content(self, content: str) -> bool:
|
||
"""判断内容是否为思考过程"""
|
||
thinking_keywords = [
|
||
"让我想想", "让我来分析", "让我思考", "我来分析",
|
||
"我需要思考", "让我仔细考虑", "思考:", "分析:",
|
||
"<think>", "*思考*", "*分析*"
|
||
]
|
||
|
||
content_lower = content.lower()
|
||
return any(keyword in content_lower for keyword in thinking_keywords)
|
||
|
||
def batch_generate(self, model: str, prompts: list, no_thinking: bool = True, **kwargs) -> list:
|
||
"""
|
||
批量生成文本
|
||
|
||
Args:
|
||
model: 模型名称
|
||
prompts: 提示列表
|
||
no_thinking: 是否禁用思考过程
|
||
**kwargs: 其他参数
|
||
"""
|
||
results = []
|
||
|
||
def generate_single(prompt):
|
||
try:
|
||
return self.generate(model, prompt, stream=False, no_thinking=no_thinking, **kwargs)
|
||
except Exception as e:
|
||
self.logger.error(f"批量生成失败 - 提示: {prompt[:50]}..., 错误: {e}")
|
||
return None
|
||
|
||
# 使用线程池执行批量请求
|
||
with ThreadPoolExecutor(max_workers=self.max_concurrent_requests) as executor:
|
||
future_to_prompt = {executor.submit(generate_single, prompt): prompt
|
||
for prompt in prompts}
|
||
|
||
for future in as_completed(future_to_prompt):
|
||
prompt = future_to_prompt[future]
|
||
try:
|
||
result = future.result()
|
||
results.append({
|
||
"prompt": prompt,
|
||
"response": result,
|
||
"success": result is not None
|
||
})
|
||
except Exception as e:
|
||
self.logger.error(f"批量请求异常: {e}")
|
||
results.append({
|
||
"prompt": prompt,
|
||
"response": None,
|
||
"success": False,
|
||
"error": str(e)
|
||
})
|
||
|
||
return results
|
||
|
||
def health_check(self) -> bool:
|
||
"""健康检查"""
|
||
try:
|
||
response = self._make_request("GET", "/api/tags")
|
||
return response.status_code == 200
|
||
except Exception as e:
|
||
self.logger.error(f"健康检查失败: {e}")
|
||
return False
|
||
|
||
def pull_model(self, model_name: str) -> bool:
|
||
"""拉取模型"""
|
||
data = {"name": model_name}
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/pull",
|
||
json=data,
|
||
stream=True
|
||
)
|
||
|
||
print(f"正在拉取模型 {model_name}...")
|
||
for chunk in self._handle_stream_response(response):
|
||
if "status" in chunk:
|
||
print(f"状态: {chunk['status']}")
|
||
if chunk.get("done", False):
|
||
print("模型拉取完成!")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"拉取模型失败: {e}")
|
||
return False
|
||
|
||
def is_model_available(self, model_name: str) -> bool:
|
||
"""检查模型是否可用"""
|
||
models = self.list_models()
|
||
if "models" in models:
|
||
return any(model["name"].startswith(model_name) for model in models["models"])
|
||
return False
|
||
|
||
def __enter__(self):
|
||
"""上下文管理器入口"""
|
||
return self
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
"""上下文管理器退出"""
|
||
self._executor.shutdown(wait=True)
|
||
self.session.close()
|
||
|
||
|
||
# Usage example / demo driver.
def main():
    # Build a client with retry and concurrency settings.
    client = OllamaClient(
        max_retries=3,
        retry_delay=1.0,
        max_concurrent_requests=3,
        timeout=30,
        enable_logging=True
    )

    # Health check: bail out early if the service is down.
    if not client.health_check():
        print("Ollama服务不可用,请检查服务状态")
        return

    # Verify the service is reachable and list the installed models.
    try:
        models = client.list_models()
        print("Ollama服务连接成功!")
        print(f"可用模型: {[m['name'] for m in models.get('models', [])]}")
    except Exception as e:
        print(f"无法连接到Ollama服务: {e}")
        return

    model_name = "Qwen3-8B:latest"  # adjust to the model you have installed

    # Pull the model if it is not installed yet.
    if not client.is_model_available(model_name):
        print(f"模型 {model_name} 不存在,正在拉取...")
        if not client.pull_model(model_name):
            print("模型拉取失败")
            return

    # Example 1: direct-answer mode (thinking process suppressed).
    print("\n=== 直接回答模式 ===")
    try:
        response = client.generate(
            model=model_name,
            prompt="请用中文简单介绍一下人工智能。",
            no_thinking=True,  # enable direct-answer mode
            temperature=0.7,
            max_tokens=200
        )

        if response:
            print(f"回答: {response.get('response', '')}")
        else:
            print("请求失败")
    except Exception as e:
        print(f"请求异常: {e}")

    # Example 2: comparison — same prompt with the thinking process left in.
    print("\n=== 对比:有思考过程的回答 ===")
    try:
        response = client.generate(
            model=model_name,
            prompt="请用中文简单介绍一下人工智能。",
            no_thinking=False,  # disable direct-answer mode
            temperature=0.7,
            max_tokens=200
        )

        if response:
            print(f"回答: {response.get('response', '')}")
        else:
            print("请求失败")
    except Exception as e:
        print(f"请求异常: {e}")

    # Example 3: streaming output in direct-answer mode.
    print("\n=== 流式输出 - 直接回答模式 ===")
    try:
        stream_response = client.generate(
            model=model_name,
            prompt="请讲一个简短的故事。",
            stream=True,
            no_thinking=True,  # enable direct-answer mode
            temperature=0.8,
            max_tokens=200
        )

        print("流式回答: ", end="")
        for chunk in stream_response:
            if "response" in chunk:
                print(chunk["response"], end="", flush=True)
            if chunk.get("done", False):
                print("\n")
                break
    except Exception as e:
        print(f"流式输出异常: {e}")

    # Example 4: chat mode with direct answers.
    print("\n=== 对话模式 - 直接回答 ===")
    try:
        messages = [
            {"role": "user", "content": "请介绍一下机器学习的基本概念"},
        ]

        chat_response = client.chat(
            model=model_name,
            messages=messages,
            no_thinking=True,  # enable direct-answer mode
            temperature=0.7,
            max_tokens=200
        )

        if chat_response:
            print(f"AI回答: {chat_response.get('message', {}).get('content', '')}")
    except Exception as e:
        print(f"对话异常: {e}")

    # Example 5: batch generation in direct-answer mode.
    print("\n=== 批量请求测试 - 直接回答模式 ===")
    prompts = [
        "什么是机器学习?",
        "什么是深度学习?",
        "什么是自然语言处理?",
        "什么是计算机视觉?",
        "什么是强化学习?"
    ]

    try:
        batch_results = client.batch_generate(
            model=model_name,
            prompts=prompts,
            no_thinking=True,  # enable direct-answer mode
            temperature=0.7,
            max_tokens=100
        )

        for i, result in enumerate(batch_results):
            if result["success"]:
                print(f"请求 {i+1} 成功: {result['response'].get('response', '')[:100]}...")
            else:
                print(f"请求 {i+1} 失败: {result.get('error', '未知错误')}")

    except Exception as e:
        print(f"批量请求异常: {e}")

    # Dump aggregate request statistics.
    print("\n=== 请求统计 ===")
    stats = client.get_stats()
    for key, value in stats.items():
        print(f"{key}: {value}")
|
||
|
||
|
||
# Run the demo only when executed as a script.
if __name__ == "__main__":
    main()