570 lines
19 KiB
Python
570 lines
19 KiB
Python
import requests
|
||
import json
|
||
import time
|
||
import threading
|
||
from typing import Dict, Any, Iterator, Optional
|
||
from functools import wraps
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import logging
|
||
|
||
class OllamaException(Exception):
|
||
"""Ollama自定义异常类"""
|
||
pass
|
||
|
||
class OllamaServerError(OllamaException):
|
||
"""服务器错误异常"""
|
||
pass
|
||
|
||
class OllamaRateLimitError(OllamaException):
|
||
"""请求限制异常"""
|
||
pass
|
||
|
||
class OllamaClient:
|
||
def __init__(self, base_url: str = "http://localhost:11434", max_retries: int = 3,
|
||
retry_delay: float = 1.0, max_concurrent_requests: int = 5,
|
||
timeout: int = 30, enable_logging: bool = True):
|
||
"""
|
||
初始化Ollama客户端
|
||
|
||
Args:
|
||
base_url: Ollama服务的基础URL,默认为本地11434端口
|
||
max_retries: 最大重试次数
|
||
retry_delay: 重试间隔(秒)
|
||
max_concurrent_requests: 最大并发请求数
|
||
timeout: 请求超时时间(秒)
|
||
enable_logging: 是否启用日志
|
||
"""
|
||
self.base_url = base_url
|
||
self.max_retries = max_retries
|
||
self.retry_delay = retry_delay
|
||
self.max_concurrent_requests = max_concurrent_requests
|
||
self.timeout = timeout
|
||
|
||
# 设置日志
|
||
if enable_logging:
|
||
logging.basicConfig(level=logging.INFO)
|
||
self.logger = logging.getLogger(__name__)
|
||
else:
|
||
self.logger = logging.getLogger(__name__)
|
||
self.logger.setLevel(logging.CRITICAL)
|
||
|
||
# 创建会话
|
||
self.session = requests.Session()
|
||
self.session.timeout = timeout
|
||
|
||
# 并发控制
|
||
self._request_lock = threading.Lock()
|
||
self._active_requests = 0
|
||
self._executor = ThreadPoolExecutor(max_workers=max_concurrent_requests)
|
||
|
||
# 请求统计
|
||
self._request_count = 0
|
||
self._success_count = 0
|
||
self._error_count = 0
|
||
self._last_request_time = 0
|
||
|
||
def _wait_for_slot(self):
|
||
"""等待可用的请求槽位"""
|
||
while self._active_requests >= self.max_concurrent_requests:
|
||
time.sleep(0.1)
|
||
|
||
def _acquire_request_slot(self):
|
||
"""获取请求槽位"""
|
||
with self._request_lock:
|
||
self._wait_for_slot()
|
||
self._active_requests += 1
|
||
self._request_count += 1
|
||
|
||
def _release_request_slot(self):
|
||
"""释放请求槽位"""
|
||
with self._request_lock:
|
||
self._active_requests = max(0, self._active_requests - 1)
|
||
|
||
def _handle_response_error(self, response, attempt: int):
|
||
"""处理响应错误"""
|
||
if response.status_code == 500:
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt) # 指数退避
|
||
self.logger.warning(f"服务器错误500,第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True # 需要重试
|
||
else:
|
||
raise OllamaServerError(f"服务器错误500,已达到最大重试次数{self.max_retries}")
|
||
|
||
elif response.status_code == 429:
|
||
# 速率限制
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"请求过于频繁(429),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True
|
||
else:
|
||
raise OllamaRateLimitError(f"请求过于频繁,已达到最大重试次数{self.max_retries}")
|
||
|
||
elif response.status_code == 503:
|
||
# 服务不可用
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"服务不可用(503),第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
return True
|
||
else:
|
||
raise OllamaServerError(f"服务不可用,已达到最大重试次数{self.max_retries}")
|
||
|
||
else:
|
||
# 其他HTTP错误
|
||
response.raise_for_status()
|
||
|
||
return False
|
||
|
||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||
"""发送请求的通用方法,包含重试逻辑"""
|
||
url = f"{self.base_url}{endpoint}"
|
||
|
||
for attempt in range(self.max_retries + 1):
|
||
try:
|
||
self._acquire_request_slot()
|
||
|
||
# 速率限制检查
|
||
current_time = time.time()
|
||
if current_time - self._last_request_time < 0.1: # 最小间隔100ms
|
||
time.sleep(0.1)
|
||
|
||
self._last_request_time = current_time
|
||
|
||
self.logger.debug(f"发送请求: {method} {url}, 尝试 {attempt + 1}/{self.max_retries + 1}")
|
||
|
||
if method.upper() == "GET":
|
||
response = self.session.get(url, **kwargs)
|
||
else:
|
||
response = self.session.post(url, **kwargs)
|
||
|
||
# 检查响应状态
|
||
if response.status_code == 200:
|
||
self._success_count += 1
|
||
return response
|
||
else:
|
||
# 处理错误响应
|
||
should_retry = self._handle_response_error(response, attempt)
|
||
if not should_retry:
|
||
break
|
||
|
||
except (requests.exceptions.ConnectTimeout,
|
||
requests.exceptions.ReadTimeout,
|
||
requests.exceptions.ConnectionError) as e:
|
||
|
||
if attempt < self.max_retries:
|
||
wait_time = self.retry_delay * (2 ** attempt)
|
||
self.logger.warning(f"网络错误: {e}, 第{attempt + 1}次重试,等待{wait_time:.1f}秒...")
|
||
time.sleep(wait_time)
|
||
else:
|
||
self._error_count += 1
|
||
raise OllamaException(f"网络连接失败: {e}")
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
self._error_count += 1
|
||
raise OllamaException(f"请求失败: {e}")
|
||
|
||
finally:
|
||
self._release_request_slot()
|
||
|
||
# 如果所有重试都失败了
|
||
self._error_count += 1
|
||
raise OllamaException("所有重试都失败了")
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取请求统计信息"""
|
||
return {
|
||
"total_requests": self._request_count,
|
||
"successful_requests": self._success_count,
|
||
"failed_requests": self._error_count,
|
||
"active_requests": self._active_requests,
|
||
"success_rate": self._success_count / max(self._request_count, 1) * 100
|
||
}
|
||
|
||
def list_models(self) -> Dict[str, Any]:
|
||
"""获取已安装的模型列表"""
|
||
try:
|
||
response = self._make_request("GET", "/api/tags")
|
||
return response.json()
|
||
except Exception as e:
|
||
self.logger.error(f"获取模型列表失败: {e}")
|
||
return {}
|
||
|
||
def generate(self, model: str, prompt: str, stream: bool = False, **kwargs) -> Any:
|
||
"""
|
||
生成文本
|
||
|
||
Args:
|
||
model: 模型名称
|
||
prompt: 输入提示
|
||
stream: 是否流式输出
|
||
**kwargs: 其他参数如temperature, top_p等
|
||
"""
|
||
data = {
|
||
"model": model,
|
||
"prompt": prompt,
|
||
"stream": stream,
|
||
**kwargs
|
||
}
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/generate",
|
||
json=data,
|
||
stream=stream
|
||
)
|
||
|
||
if stream:
|
||
return self._handle_stream_response(response)
|
||
else:
|
||
return response.json()
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"生成请求失败: {e}")
|
||
return None
|
||
|
||
def chat(self, model: str, messages: list, stream: bool = False, **kwargs) -> Any:
|
||
"""
|
||
对话模式
|
||
|
||
Args:
|
||
model: 模型名称
|
||
messages: 消息列表,格式为[{"role": "user", "content": "..."}]
|
||
stream: 是否流式输出
|
||
**kwargs: 其他参数
|
||
"""
|
||
data = {
|
||
"model": model,
|
||
"messages": messages,
|
||
"stream": stream,
|
||
**kwargs
|
||
}
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/chat",
|
||
json=data,
|
||
stream=stream
|
||
)
|
||
|
||
if stream:
|
||
return self._handle_stream_response(response)
|
||
else:
|
||
return response.json()
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"对话请求失败: {e}")
|
||
return None
|
||
|
||
def _handle_stream_response(self, response) -> Iterator[Dict[str, Any]]:
|
||
"""处理流式响应"""
|
||
for line in response.iter_lines():
|
||
if line:
|
||
try:
|
||
yield json.loads(line.decode('utf-8'))
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
def batch_generate(self, model: str, prompts: list, **kwargs) -> list:
|
||
"""
|
||
批量生成文本
|
||
|
||
Args:
|
||
model: 模型名称
|
||
prompts: 提示列表
|
||
**kwargs: 其他参数
|
||
"""
|
||
results = []
|
||
|
||
def generate_single(prompt):
|
||
try:
|
||
return self.generate(model, prompt, stream=False, **kwargs)
|
||
except Exception as e:
|
||
self.logger.error(f"批量生成失败 - 提示: {prompt[:50]}..., 错误: {e}")
|
||
return None
|
||
|
||
# 使用线程池执行批量请求
|
||
with ThreadPoolExecutor(max_workers=self.max_concurrent_requests) as executor:
|
||
future_to_prompt = {executor.submit(generate_single, prompt): prompt
|
||
for prompt in prompts}
|
||
|
||
for future in as_completed(future_to_prompt):
|
||
prompt = future_to_prompt[future]
|
||
try:
|
||
result = future.result()
|
||
results.append({
|
||
"prompt": prompt,
|
||
"response": result,
|
||
"success": result is not None
|
||
})
|
||
except Exception as e:
|
||
self.logger.error(f"批量请求异常: {e}")
|
||
results.append({
|
||
"prompt": prompt,
|
||
"response": None,
|
||
"success": False,
|
||
"error": str(e)
|
||
})
|
||
|
||
return results
|
||
|
||
def health_check(self) -> bool:
|
||
"""健康检查"""
|
||
try:
|
||
response = self._make_request("GET", "/api/tags")
|
||
return response.status_code == 200
|
||
except Exception as e:
|
||
self.logger.error(f"健康检查失败: {e}")
|
||
return False
|
||
|
||
def pull_model(self, model_name: str) -> bool:
|
||
"""拉取模型"""
|
||
data = {"name": model_name}
|
||
|
||
try:
|
||
response = self._make_request(
|
||
"POST",
|
||
"/api/pull",
|
||
json=data,
|
||
stream=True
|
||
)
|
||
|
||
print(f"正在拉取模型 {model_name}...")
|
||
for chunk in self._handle_stream_response(response):
|
||
if "status" in chunk:
|
||
print(f"状态: {chunk['status']}")
|
||
if chunk.get("done", False):
|
||
print("模型拉取完成!")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"拉取模型失败: {e}")
|
||
return False
|
||
|
||
def is_model_available(self, model_name: str) -> bool:
|
||
"""检查模型是否可用"""
|
||
models = self.list_models()
|
||
if "models" in models:
|
||
return any(model["name"].startswith(model_name) for model in models["models"])
|
||
return False
|
||
|
||
def __enter__(self):
|
||
"""上下文管理器入口"""
|
||
return self
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
"""上下文管理器退出"""
|
||
self._executor.shutdown(wait=True)
|
||
self.session.close()
|
||
|
||
|
||
# 使用示例
|
||
def main():
|
||
# 创建客户端,配置重试和并发参数
|
||
client = OllamaClient(
|
||
max_retries=3,
|
||
retry_delay=1.0,
|
||
max_concurrent_requests=3,
|
||
timeout=30,
|
||
enable_logging=True
|
||
)
|
||
|
||
# 健康检查
|
||
if not client.health_check():
|
||
print("Ollama服务不可用,请检查服务状态")
|
||
return
|
||
|
||
# 检查服务是否可用
|
||
try:
|
||
models = client.list_models()
|
||
print("Ollama服务连接成功!")
|
||
print(f"可用模型: {[m['name'] for m in models.get('models', [])]}")
|
||
except Exception as e:
|
||
print(f"无法连接到Ollama服务: {e}")
|
||
return
|
||
|
||
model_name = "Qwen3-8B:latest" # 根据你的模型调整
|
||
|
||
# 检查模型是否存在,不存在则拉取
|
||
if not client.is_model_available(model_name):
|
||
print(f"模型 {model_name} 不存在,正在拉取...")
|
||
if not client.pull_model(model_name):
|
||
print("模型拉取失败")
|
||
return
|
||
|
||
# 示例1: 处理500错误的单个请求
|
||
# print("\n=== 单个请求(含错误处理) ===")
|
||
# try:
|
||
# response = client.generate(
|
||
# model=model_name,
|
||
# prompt="请用中文简单介绍一下人工智能。",
|
||
# no_thinking=True,
|
||
# temperature=0.7,
|
||
# max_tokens=200
|
||
# )
|
||
|
||
# if response:
|
||
# print(f"回答: {response.get('response', '')}")
|
||
# else:
|
||
# print("请求失败")
|
||
# except Exception as e:
|
||
# print(f"请求异常: {e}")
|
||
|
||
# 示例2: 流式输出
|
||
print("\n=== 流式输出 ===")
|
||
try:
|
||
stream_response = client.generate(
|
||
model=model_name,
|
||
prompt="请讲一个简短的故事。",
|
||
stream=True,
|
||
no_thinking=True,
|
||
enable_thinking = False,
|
||
temperature=0.8,
|
||
max_tokens=200
|
||
)
|
||
|
||
print("流式回答: ", end="")
|
||
for chunk in stream_response:
|
||
if "response" in chunk:
|
||
print(chunk["response"], end="", flush=True)
|
||
if chunk.get("done", False):
|
||
print("\n")
|
||
break
|
||
except Exception as e:
|
||
print(f"流式输出异常: {e}")
|
||
|
||
# 示例3: 对话模式
|
||
print("\n=== 对话模式 ===")
|
||
try:
|
||
messages = [
|
||
{"role" : "user", "content" : "请介绍一下昊天"},
|
||
]
|
||
|
||
chat_response = client.chat(
|
||
model=model_name,
|
||
messages=messages,
|
||
# no_thinking=True,
|
||
# # no_thinking=True,
|
||
# enable_thinking = True,
|
||
temperature=0.7,
|
||
max_tokens=200
|
||
)
|
||
|
||
if chat_response:
|
||
print(f"AI回答: {chat_response.get('message', {}).get('content', '')}")
|
||
except Exception as e:
|
||
print(f"对话异常: {e}")
|
||
|
||
# 示例4: 批量请求测试
|
||
print("\n=== 批量请求测试 ===")
|
||
prompts = [
|
||
"什么是机器学习?",
|
||
"什么是深度学习?",
|
||
"什么是自然语言处理?",
|
||
"什么是计算机视觉?",
|
||
"什么是强化学习?"
|
||
]
|
||
|
||
try:
|
||
batch_results = client.batch_generate(
|
||
model=model_name,
|
||
prompts=prompts,
|
||
temperature=0.7,
|
||
max_tokens=100
|
||
)
|
||
|
||
for i, result in enumerate(batch_results):
|
||
if result["success"]:
|
||
print(f"请求 {i+1} 成功: {result['response'].get('response', '')[:50]}...")
|
||
else:
|
||
print(f"请求 {i+1} 失败: {result.get('error', '未知错误')}")
|
||
|
||
except Exception as e:
|
||
print(f"批量请求异常: {e}")
|
||
|
||
# 示例5: 并发压力测试
|
||
print("\n=== 并发压力测试 ===")
|
||
|
||
def stress_test():
|
||
import concurrent.futures
|
||
|
||
def single_request(i):
|
||
try:
|
||
response = client.generate(
|
||
model=model_name,
|
||
prompt=f"请简单回答:什么是人工智能?(请求{i})",
|
||
temperature=0.7,
|
||
max_tokens=50
|
||
)
|
||
return f"请求{i}: 成功" if response else f"请求{i}: 失败"
|
||
except Exception as e:
|
||
return f"请求{i}: 异常 - {e}"
|
||
|
||
# 发送10个并发请求
|
||
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
|
||
futures = [executor.submit(single_request, i) for i in range(10)]
|
||
|
||
for future in concurrent.futures.as_completed(futures):
|
||
try:
|
||
result = future.result()
|
||
print(result)
|
||
except Exception as e:
|
||
print(f"并发请求异常: {e}")
|
||
|
||
stress_test()
|
||
|
||
# 显示统计信息
|
||
print("\n=== 请求统计 ===")
|
||
stats = client.get_stats()
|
||
for key, value in stats.items():
|
||
print(f"{key}: {value}")
|
||
|
||
|
||
def demo_error_handling():
|
||
"""演示错误处理"""
|
||
print("\n=== 错误处理演示 ===")
|
||
|
||
# 模拟连接错误的客户端
|
||
client = OllamaClient(
|
||
base_url="http://localhost:99999", # 无效端口
|
||
max_retries=2,
|
||
retry_delay=0.5
|
||
)
|
||
|
||
try:
|
||
response = client.generate(
|
||
model="llama2",
|
||
prompt="测试请求",
|
||
temperature=0.7
|
||
)
|
||
print(f"意外成功: {response}")
|
||
except Exception as e:
|
||
print(f"预期的错误: {e}")
|
||
|
||
# 显示错误统计
|
||
stats = client.get_stats()
|
||
print(f"错误统计: {stats}")
|
||
|
||
|
||
def demo_context_manager():
|
||
"""演示上下文管理器使用"""
|
||
print("\n=== 上下文管理器演示 ===")
|
||
|
||
with OllamaClient(max_retries=2) as client:
|
||
if client.health_check():
|
||
print("服务健康检查通过")
|
||
|
||
# 使用客户端进行请求
|
||
models = client.list_models()
|
||
print(f"可用模型数量: {len(models.get('models', []))}")
|
||
else:
|
||
print("服务健康检查失败")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
demo_error_handling()
|
||
demo_context_manager() |