# unsloth/006ollamaApi增加错误处理.py (Ollama API client with added error handling)

import requests
import json
import time
import threading
from typing import Dict, Any, Iterator, Optional
from functools import wraps
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging

class OllamaException(Exception):
    """Custom base exception class for Ollama errors."""
    pass


class OllamaServerError(OllamaException):
    """Server error exception."""
    pass


class OllamaRateLimitError(OllamaException):
    """Rate limit exception."""
    pass

class OllamaClient:
    def __init__(self, base_url: str = "http://localhost:11434", max_retries: int = 3,
                 retry_delay: float = 1.0, max_concurrent_requests: int = 5,
                 timeout: int = 30, enable_logging: bool = True):
        """
        Initialize the Ollama client.

        Args:
            base_url: Base URL of the Ollama service; defaults to port 11434 on localhost.
            max_retries: Maximum number of retries.
            retry_delay: Delay between retries (seconds).
            max_concurrent_requests: Maximum number of concurrent requests.
            timeout: Request timeout (seconds).
            enable_logging: Whether to enable logging.
        """
        self.base_url = base_url
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.max_concurrent_requests = max_concurrent_requests
        self.timeout = timeout
        # Set up logging
        if enable_logging:
            logging.basicConfig(level=logging.INFO)
            self.logger = logging.getLogger(__name__)
        else:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(logging.CRITICAL)
        # Create the HTTP session. requests.Session has no timeout attribute,
        # so the timeout is applied per request in _make_request instead.
        self.session = requests.Session()
        # Concurrency control
        self._request_lock = threading.Lock()
        self._active_requests = 0
        self._executor = ThreadPoolExecutor(max_workers=max_concurrent_requests)
        # Request statistics
        self._request_count = 0
        self._success_count = 0
        self._error_count = 0
        self._last_request_time = 0
    def _acquire_request_slot(self):
        """Acquire a request slot, blocking until one is free.

        The availability check and the counter update happen atomically under
        the lock, which is released between polls so that other threads can
        call _release_request_slot without deadlocking.
        """
        while True:
            with self._request_lock:
                if self._active_requests < self.max_concurrent_requests:
                    self._active_requests += 1
                    self._request_count += 1
                    return
            time.sleep(0.1)

    def _release_request_slot(self):
        """Release a request slot."""
        with self._request_lock:
            self._active_requests = max(0, self._active_requests - 1)
    def _handle_response_error(self, response, attempt: int):
        """Handle an error response; returns True if the request should be retried."""
        if response.status_code == 500:
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)  # exponential backoff
                self.logger.warning(f"Server error (500), retry {attempt + 1}, waiting {wait_time:.1f}s...")
                time.sleep(wait_time)
                return True  # retry needed
            else:
                raise OllamaServerError(f"Server error (500), reached max retries ({self.max_retries})")
        elif response.status_code == 429:
            # Rate limited
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)
                self.logger.warning(f"Too many requests (429), retry {attempt + 1}, waiting {wait_time:.1f}s...")
                time.sleep(wait_time)
                return True
            else:
                raise OllamaRateLimitError(f"Too many requests, reached max retries ({self.max_retries})")
        elif response.status_code == 503:
            # Service unavailable
            if attempt < self.max_retries:
                wait_time = self.retry_delay * (2 ** attempt)
                self.logger.warning(f"Service unavailable (503), retry {attempt + 1}, waiting {wait_time:.1f}s...")
                time.sleep(wait_time)
                return True
            else:
                raise OllamaServerError(f"Service unavailable, reached max retries ({self.max_retries})")
        else:
            # Other HTTP errors
            response.raise_for_status()
            return False
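    # A worked illustration of the backoff schedule above (a sketch, not part
    # of the client): with the defaults retry_delay=1.0 and max_retries=3,
    # retry_delay * (2 ** attempt) produces waits of 1.0s, 2.0s and 4.0s, after
    # which the final attempt's failure is raised.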
    def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
        """Generic request method with retry logic."""
        url = f"{self.base_url}{endpoint}"
        # Apply the client-level timeout per request (Session has no timeout attribute)
        kwargs.setdefault("timeout", self.timeout)
        for attempt in range(self.max_retries + 1):
            try:
                self._acquire_request_slot()
                # Rate-limit check: enforce a minimum 100 ms gap between requests
                current_time = time.time()
                if current_time - self._last_request_time < 0.1:
                    time.sleep(0.1)
                self._last_request_time = current_time
                self.logger.debug(f"Sending request: {method} {url}, attempt {attempt + 1}/{self.max_retries + 1}")
                if method.upper() == "GET":
                    response = self.session.get(url, **kwargs)
                else:
                    response = self.session.post(url, **kwargs)
                # Check the response status
                if response.status_code == 200:
                    self._success_count += 1
                    return response
                else:
                    # Handle the error response
                    should_retry = self._handle_response_error(response, attempt)
                    if not should_retry:
                        break
            except (requests.exceptions.ConnectTimeout,
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError) as e:
                if attempt < self.max_retries:
                    wait_time = self.retry_delay * (2 ** attempt)
                    self.logger.warning(f"Network error: {e}, retry {attempt + 1}, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                else:
                    self._error_count += 1
                    raise OllamaException(f"Network connection failed: {e}")
            except requests.exceptions.RequestException as e:
                self._error_count += 1
                raise OllamaException(f"Request failed: {e}")
            finally:
                self._release_request_slot()
        # All retries have failed
        self._error_count += 1
        raise OllamaException("All retries failed")
    def get_stats(self) -> Dict[str, Any]:
        """Return request statistics."""
        return {
            "total_requests": self._request_count,
            "successful_requests": self._success_count,
            "failed_requests": self._error_count,
            "active_requests": self._active_requests,
            "success_rate": self._success_count / max(self._request_count, 1) * 100
        }
    def list_models(self) -> Dict[str, Any]:
        """List the installed models."""
        try:
            response = self._make_request("GET", "/api/tags")
            return response.json()
        except Exception as e:
            self.logger.error(f"Failed to list models: {e}")
            return {}
    def generate(self, model: str, prompt: str, stream: bool = False, **kwargs) -> Any:
        """
        Generate text.

        Args:
            model: Model name.
            prompt: Input prompt.
            stream: Whether to stream the output.
            **kwargs: Other parameters such as temperature, top_p, merged into the payload.
        """
        data = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            **kwargs
        }
        try:
            response = self._make_request(
                "POST",
                "/api/generate",
                json=data,
                stream=stream
            )
            if stream:
                return self._handle_stream_response(response)
            else:
                return response.json()
        except Exception as e:
            self.logger.error(f"Generate request failed: {e}")
            return None
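    # Note on parameters: the Ollama REST API expects sampling parameters in a
    # nested "options" object, and its token cap is called "num_predict" (there
    # is no top-level "max_tokens" field). A hedged sketch of a spec-conformant
    # call via the **kwargs passthrough above, assuming a local "llama3" model:
    #
    #   client.generate(
    #       "llama3",
    #       "Why is the sky blue?",
    #       options={"temperature": 0.7, "num_predict": 200},
    #   )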
    def chat(self, model: str, messages: list, stream: bool = False, **kwargs) -> Any:
        """
        Chat mode.

        Args:
            model: Model name.
            messages: Message list in the form [{"role": "user", "content": "..."}].
            stream: Whether to stream the output.
            **kwargs: Other parameters merged into the payload.
        """
        data = {
            "model": model,
            "messages": messages,
            "stream": stream,
            **kwargs
        }
        try:
            response = self._make_request(
                "POST",
                "/api/chat",
                json=data,
                stream=stream
            )
            if stream:
                return self._handle_stream_response(response)
            else:
                return response.json()
        except Exception as e:
            self.logger.error(f"Chat request failed: {e}")
            return None
    def _handle_stream_response(self, response) -> Iterator[Dict[str, Any]]:
        """Parse a streaming response, yielding one JSON object per line."""
        for line in response.iter_lines():
            if line:
                try:
                    yield json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    continue
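    # Sketch of consuming the stream (assumes a locally installed "llama3").
    # Ollama streams newline-delimited JSON, one object per line, with a final
    # object carrying "done": true, so callers typically loop like this:
    #
    #   for chunk in client.generate("llama3", "Hi", stream=True):
    #       print(chunk.get("response", ""), end="", flush=True)
    #       if chunk.get("done"):
    #           break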
    def batch_generate(self, model: str, prompts: list, **kwargs) -> list:
        """
        Generate text for a batch of prompts.

        Args:
            model: Model name.
            prompts: List of prompts.
            **kwargs: Other parameters merged into each payload.
        """
        results = []

        def generate_single(prompt):
            try:
                return self.generate(model, prompt, stream=False, **kwargs)
            except Exception as e:
                self.logger.error(f"Batch generation failed - prompt: {prompt[:50]}..., error: {e}")
                return None

        # Run the batch through a thread pool
        with ThreadPoolExecutor(max_workers=self.max_concurrent_requests) as executor:
            future_to_prompt = {executor.submit(generate_single, prompt): prompt
                                for prompt in prompts}
            for future in as_completed(future_to_prompt):
                prompt = future_to_prompt[future]
                try:
                    result = future.result()
                    results.append({
                        "prompt": prompt,
                        "response": result,
                        "success": result is not None
                    })
                except Exception as e:
                    self.logger.error(f"Batch request error: {e}")
                    results.append({
                        "prompt": prompt,
                        "response": None,
                        "success": False,
                        "error": str(e)
                    })
        return results
    def health_check(self) -> bool:
        """Health check."""
        try:
            response = self._make_request("GET", "/api/tags")
            return response.status_code == 200
        except Exception as e:
            self.logger.error(f"Health check failed: {e}")
            return False
    def pull_model(self, model_name: str) -> bool:
        """Pull a model."""
        data = {"name": model_name}
        try:
            response = self._make_request(
                "POST",
                "/api/pull",
                json=data,
                stream=True
            )
            print(f"Pulling model {model_name}...")
            for chunk in self._handle_stream_response(response):
                if "status" in chunk:
                    print(f"Status: {chunk['status']}")
                if chunk.get("done", False):
                    print("Model pull complete!")
                    return True
            # The stream ended without a "done" marker
            return False
        except Exception as e:
            self.logger.error(f"Failed to pull model: {e}")
            return False
    def is_model_available(self, model_name: str) -> bool:
        """Check whether a model is available locally."""
        models = self.list_models()
        if "models" in models:
            return any(model["name"].startswith(model_name) for model in models["models"])
        return False
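    # Worked example of the prefix match above: is_model_available("Qwen3-8B")
    # returns True when the tags list contains "Qwen3-8B:latest", because each
    # installed model name is checked with str.startswith(model_name).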
    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self._executor.shutdown(wait=True)
        self.session.close()

# Usage example
def main():
    # Create a client with retry and concurrency settings
    client = OllamaClient(
        max_retries=3,
        retry_delay=1.0,
        max_concurrent_requests=3,
        timeout=30,
        enable_logging=True
    )
    # Health check
    if not client.health_check():
        print("Ollama service is unavailable; please check the service status")
        return
    # Verify the service is reachable
    try:
        models = client.list_models()
        print("Connected to the Ollama service!")
        print(f"Available models: {[m['name'] for m in models.get('models', [])]}")
    except Exception as e:
        print(f"Could not connect to the Ollama service: {e}")
        return
    model_name = "Qwen3-8B:latest"  # adjust to match your model
    # Pull the model if it is not installed yet
    if not client.is_model_available(model_name):
        print(f"Model {model_name} not found, pulling...")
        if not client.pull_model(model_name):
            print("Model pull failed")
            return
    # Example 1: single request with 500-error handling
    # print("\n=== Single request (with error handling) ===")
    # try:
    #     response = client.generate(
    #         model=model_name,
    #         prompt="Please give a brief introduction to artificial intelligence, in Chinese.",
    #         no_thinking=True,
    #         temperature=0.7,
    #         max_tokens=200
    #     )
    #     if response:
    #         print(f"Answer: {response.get('response', '')}")
    #     else:
    #         print("Request failed")
    # except Exception as e:
    #     print(f"Request error: {e}")
    # Example 2: streaming output
    print("\n=== Streaming output ===")
    try:
        stream_response = client.generate(
            model=model_name,
            prompt="Please tell a short story.",
            stream=True,
            no_thinking=True,       # non-standard, model-specific fields;
            enable_thinking=False,  # passed through verbatim into the payload
            temperature=0.8,
            max_tokens=200
        )
        print("Streamed answer: ", end="")
        for chunk in stream_response:
            if "response" in chunk:
                print(chunk["response"], end="", flush=True)
            if chunk.get("done", False):
                print("\n")
                break
    except Exception as e:
        print(f"Streaming error: {e}")
    # Example 3: chat mode
    print("\n=== Chat mode ===")
    try:
        messages = [
            {"role": "user", "content": "Please introduce 昊天 (Haotian)"},
        ]
        chat_response = client.chat(
            model=model_name,
            messages=messages,
            # no_thinking=True,
            # enable_thinking=True,
            temperature=0.7,
            max_tokens=200
        )
        if chat_response:
            print(f"AI answer: {chat_response.get('message', {}).get('content', '')}")
    except Exception as e:
        print(f"Chat error: {e}")
    # Example 4: batch request test
    print("\n=== Batch request test ===")
    prompts = [
        "What is machine learning?",
        "What is deep learning?",
        "What is natural language processing?",
        "What is computer vision?",
        "What is reinforcement learning?"
    ]
    try:
        batch_results = client.batch_generate(
            model=model_name,
            prompts=prompts,
            temperature=0.7,
            max_tokens=100
        )
        for i, result in enumerate(batch_results):
            if result["success"]:
                print(f"Request {i+1} succeeded: {result['response'].get('response', '')[:50]}...")
            else:
                print(f"Request {i+1} failed: {result.get('error', 'unknown error')}")
    except Exception as e:
        print(f"Batch request error: {e}")
    # Example 5: concurrency stress test
    print("\n=== Concurrency stress test ===")

    def stress_test():
        import concurrent.futures

        def single_request(i):
            try:
                response = client.generate(
                    model=model_name,
                    prompt=f"Answer briefly: what is artificial intelligence? (request {i})",
                    temperature=0.7,
                    max_tokens=50
                )
                return f"Request {i}: success" if response else f"Request {i}: failure"
            except Exception as e:
                return f"Request {i}: error - {e}"

        # Fire 10 concurrent requests
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(single_request, i) for i in range(10)]
            for future in concurrent.futures.as_completed(futures):
                try:
                    result = future.result()
                    print(result)
                except Exception as e:
                    print(f"Concurrent request error: {e}")

    stress_test()
    # Show statistics
    print("\n=== Request statistics ===")
    stats = client.get_stats()
    for key, value in stats.items():
        print(f"{key}: {value}")

def demo_error_handling():
    """Demonstrate error handling."""
    print("\n=== Error handling demo ===")
    # Client pointed at an unreachable address to simulate connection errors
    client = OllamaClient(
        base_url="http://localhost:99999",  # invalid port
        max_retries=2,
        retry_delay=0.5
    )
    try:
        response = client.generate(
            model="llama2",
            prompt="Test request",
            temperature=0.7
        )
        # generate() swallows its exceptions and returns None, so the failure
        # surfaces here as a None response rather than a raised exception
        if response is None:
            print("Request failed as expected (generate() returned None)")
        else:
            print(f"Unexpected success: {response}")
    except Exception as e:
        print(f"Expected error: {e}")
    # Show the error statistics
    stats = client.get_stats()
    print(f"Error statistics: {stats}")

def demo_context_manager():
    """Demonstrate context-manager usage."""
    print("\n=== Context manager demo ===")
    with OllamaClient(max_retries=2) as client:
        if client.health_check():
            print("Service health check passed")
            # Use the client for requests
            models = client.list_models()
            print(f"Number of available models: {len(models.get('models', []))}")
        else:
            print("Service health check failed")

if __name__ == "__main__":
    main()
    demo_error_handling()
    demo_context_manager()