import requests import json import time import threading from typing import Dict, Any, Iterator, Optional from functools import wraps from concurrent.futures import ThreadPoolExecutor, as_completed import logging class OllamaException(Exception): """Ollama自定义异常类""" pass class OllamaServerError(OllamaException): """服务器错误异常""" pass class OllamaRateLimitError(OllamaException): """请求限制异常""" pass class OllamaClient: def __init__(self, base_url: str = "http://localhost:11434", max_retries: int = 3, retry_delay: float = 1.0, max_concurrent_requests: int = 5, timeout: int = 30, enable_logging: bool = True): """ 初始化Ollama客户端 Args: base_url: Ollama服务的基础URL,默认为本地11434端口 max_retries: 最大重试次数 retry_delay: 重试间隔(秒) max_concurrent_requests: 最大并发请求数 timeout: 请求超时时间(秒) enable_logging: 是否启用日志 """ self.base_url = base_url self.max_retries = max_retries self.retry_delay = retry_delay self.max_concurrent_requests = max_concurrent_requests self.timeout = timeout # 设置日志 if enable_logging: logging.basicConfig(level=logging.INFO) self.logger = logging.getLogger(__name__) else: self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.CRITICAL) # 创建会话 self.session = requests.Session() self.session.timeout = timeout # 并发控制 self._request_lock = threading.Lock() self._active_requests = 0 self._executor = ThreadPoolExecutor(max_workers=max_concurrent_requests) # 请求统计 self._request_count = 0 self._success_count = 0 self._error_count = 0 self._last_request_time = 0 def _wait_for_slot(self): """等待可用的请求槽位""" while self._active_requests >= self.max_concurrent_requests: time.sleep(0.1) def _acquire_request_slot(self): """获取请求槽位""" with self._request_lock: self._wait_for_slot() self._active_requests += 1 self._request_count += 1 def _release_request_slot(self): """释放请求槽位""" with self._request_lock: self._active_requests = max(0, self._active_requests - 1) def _handle_response_error(self, response, attempt: int): """处理响应错误""" if response.status_code == 500: if attempt < self.max_retries: wait_time = self.retry_delay * (2 ** attempt) # 指数退避 self.logger.warning(f"服务器错误500,第{attempt + 1}次重试,等待{wait_time:.1f}秒...") time.sleep(wait_time) return True # 需要重试 else: raise OllamaServerError(f"服务器错误500,已达到最大重试次数{self.max_retries}") elif response.status_code == 429: # 速率限制 if attempt < self.max_retries: wait_time = self.retry_delay * (2 ** attempt) self.logger.warning(f"请求过于频繁(429),第{attempt + 1}次重试,等待{wait_time:.1f}秒...") time.sleep(wait_time) return True else: raise OllamaRateLimitError(f"请求过于频繁,已达到最大重试次数{self.max_retries}") elif response.status_code == 503: # 服务不可用 if attempt < self.max_retries: wait_time = self.retry_delay * (2 ** attempt) self.logger.warning(f"服务不可用(503),第{attempt + 1}次重试,等待{wait_time:.1f}秒...") time.sleep(wait_time) return True else: raise OllamaServerError(f"服务不可用,已达到最大重试次数{self.max_retries}") else: # 其他HTTP错误 response.raise_for_status() return False def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: """发送请求的通用方法,包含重试逻辑""" url = f"{self.base_url}{endpoint}" for attempt in range(self.max_retries + 1): try: self._acquire_request_slot() # 速率限制检查 current_time = time.time() if current_time - self._last_request_time < 0.1: # 最小间隔100ms time.sleep(0.1) self._last_request_time = current_time self.logger.debug(f"发送请求: {method} {url}, 尝试 {attempt + 1}/{self.max_retries + 1}") if method.upper() == "GET": response = self.session.get(url, **kwargs) else: response = self.session.post(url, **kwargs) # 检查响应状态 if response.status_code == 200: self._success_count += 1 return response else: # 处理错误响应 should_retry = self._handle_response_error(response, attempt) if not should_retry: break except (requests.exceptions.ConnectTimeout, requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError) as e: if attempt < self.max_retries: wait_time = self.retry_delay * (2 ** attempt) self.logger.warning(f"网络错误: {e}, 第{attempt + 1}次重试,等待{wait_time:.1f}秒...") time.sleep(wait_time) else: self._error_count += 1 raise OllamaException(f"网络连接失败: {e}") except requests.exceptions.RequestException as e: self._error_count += 1 raise OllamaException(f"请求失败: {e}") finally: self._release_request_slot() # 如果所有重试都失败了 self._error_count += 1 raise OllamaException("所有重试都失败了") def get_stats(self) -> Dict[str, Any]: """获取请求统计信息""" return { "total_requests": self._request_count, "successful_requests": self._success_count, "failed_requests": self._error_count, "active_requests": self._active_requests, "success_rate": self._success_count / max(self._request_count, 1) * 100 } def list_models(self) -> Dict[str, Any]: """获取已安装的模型列表""" try: response = self._make_request("GET", "/api/tags") return response.json() except Exception as e: self.logger.error(f"获取模型列表失败: {e}") return {} def generate(self, model: str, prompt: str, stream: bool = False, **kwargs) -> Any: """ 生成文本 Args: model: 模型名称 prompt: 输入提示 stream: 是否流式输出 **kwargs: 其他参数如temperature, top_p等 """ data = { "model": model, "prompt": prompt, "stream": stream, **kwargs } try: response = self._make_request( "POST", "/api/generate", json=data, stream=stream ) if stream: return self._handle_stream_response(response) else: return response.json() except Exception as e: self.logger.error(f"生成请求失败: {e}") return None def chat(self, model: str, messages: list, stream: bool = False, **kwargs) -> Any: """ 对话模式 Args: model: 模型名称 messages: 消息列表,格式为[{"role": "user", "content": "..."}] stream: 是否流式输出 **kwargs: 其他参数 """ data = { "model": model, "messages": messages, "stream": stream, **kwargs } try: response = self._make_request( "POST", "/api/chat", json=data, stream=stream ) if stream: return self._handle_stream_response(response) else: return response.json() except Exception as e: self.logger.error(f"对话请求失败: {e}") return None def _handle_stream_response(self, response) -> Iterator[Dict[str, Any]]: """处理流式响应""" for line in response.iter_lines(): if line: try: yield json.loads(line.decode('utf-8')) except json.JSONDecodeError: continue def batch_generate(self, model: str, prompts: list, **kwargs) -> list: """ 批量生成文本 Args: model: 模型名称 prompts: 提示列表 **kwargs: 其他参数 """ results = [] def generate_single(prompt): try: return self.generate(model, prompt, stream=False, **kwargs) except Exception as e: self.logger.error(f"批量生成失败 - 提示: {prompt[:50]}..., 错误: {e}") return None # 使用线程池执行批量请求 with ThreadPoolExecutor(max_workers=self.max_concurrent_requests) as executor: future_to_prompt = {executor.submit(generate_single, prompt): prompt for prompt in prompts} for future in as_completed(future_to_prompt): prompt = future_to_prompt[future] try: result = future.result() results.append({ "prompt": prompt, "response": result, "success": result is not None }) except Exception as e: self.logger.error(f"批量请求异常: {e}") results.append({ "prompt": prompt, "response": None, "success": False, "error": str(e) }) return results def health_check(self) -> bool: """健康检查""" try: response = self._make_request("GET", "/api/tags") return response.status_code == 200 except Exception as e: self.logger.error(f"健康检查失败: {e}") return False def pull_model(self, model_name: str) -> bool: """拉取模型""" data = {"name": model_name} try: response = self._make_request( "POST", "/api/pull", json=data, stream=True ) print(f"正在拉取模型 {model_name}...") for chunk in self._handle_stream_response(response): if "status" in chunk: print(f"状态: {chunk['status']}") if chunk.get("done", False): print("模型拉取完成!") return True except Exception as e: self.logger.error(f"拉取模型失败: {e}") return False def is_model_available(self, model_name: str) -> bool: """检查模型是否可用""" models = self.list_models() if "models" in models: return any(model["name"].startswith(model_name) for model in models["models"]) return False def __enter__(self): """上下文管理器入口""" return self def __exit__(self, exc_type, exc_val, exc_tb): """上下文管理器退出""" self._executor.shutdown(wait=True) self.session.close() # 使用示例 def main(): # 创建客户端,配置重试和并发参数 client = OllamaClient( max_retries=3, retry_delay=1.0, max_concurrent_requests=3, timeout=30, enable_logging=True ) # 健康检查 if not client.health_check(): print("Ollama服务不可用,请检查服务状态") return # 检查服务是否可用 try: models = client.list_models() print("Ollama服务连接成功!") print(f"可用模型: {[m['name'] for m in models.get('models', [])]}") except Exception as e: print(f"无法连接到Ollama服务: {e}") return model_name = "Qwen3-8B:latest" # 根据你的模型调整 # 检查模型是否存在,不存在则拉取 if not client.is_model_available(model_name): print(f"模型 {model_name} 不存在,正在拉取...") if not client.pull_model(model_name): print("模型拉取失败") return # 示例1: 处理500错误的单个请求 # print("\n=== 单个请求(含错误处理) ===") # try: # response = client.generate( # model=model_name, # prompt="请用中文简单介绍一下人工智能。", # no_thinking=True, # temperature=0.7, # max_tokens=200 # ) # if response: # print(f"回答: {response.get('response', '')}") # else: # print("请求失败") # except Exception as e: # print(f"请求异常: {e}") # 示例2: 流式输出 print("\n=== 流式输出 ===") try: stream_response = client.generate( model=model_name, prompt="请讲一个简短的故事。", stream=True, no_thinking=True, enable_thinking = False, temperature=0.8, max_tokens=200 ) print("流式回答: ", end="") for chunk in stream_response: if "response" in chunk: print(chunk["response"], end="", flush=True) if chunk.get("done", False): print("\n") break except Exception as e: print(f"流式输出异常: {e}") # 示例3: 对话模式 print("\n=== 对话模式 ===") try: messages = [ {"role" : "user", "content" : "请介绍一下昊天"}, ] chat_response = client.chat( model=model_name, messages=messages, # no_thinking=True, # # no_thinking=True, # enable_thinking = True, temperature=0.7, max_tokens=200 ) if chat_response: print(f"AI回答: {chat_response.get('message', {}).get('content', '')}") except Exception as e: print(f"对话异常: {e}") # 示例4: 批量请求测试 print("\n=== 批量请求测试 ===") prompts = [ "什么是机器学习?", "什么是深度学习?", "什么是自然语言处理?", "什么是计算机视觉?", "什么是强化学习?" ] try: batch_results = client.batch_generate( model=model_name, prompts=prompts, temperature=0.7, max_tokens=100 ) for i, result in enumerate(batch_results): if result["success"]: print(f"请求 {i+1} 成功: {result['response'].get('response', '')[:50]}...") else: print(f"请求 {i+1} 失败: {result.get('error', '未知错误')}") except Exception as e: print(f"批量请求异常: {e}") # 示例5: 并发压力测试 print("\n=== 并发压力测试 ===") def stress_test(): import concurrent.futures def single_request(i): try: response = client.generate( model=model_name, prompt=f"请简单回答:什么是人工智能?(请求{i})", temperature=0.7, max_tokens=50 ) return f"请求{i}: 成功" if response else f"请求{i}: 失败" except Exception as e: return f"请求{i}: 异常 - {e}" # 发送10个并发请求 with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: futures = [executor.submit(single_request, i) for i in range(10)] for future in concurrent.futures.as_completed(futures): try: result = future.result() print(result) except Exception as e: print(f"并发请求异常: {e}") stress_test() # 显示统计信息 print("\n=== 请求统计 ===") stats = client.get_stats() for key, value in stats.items(): print(f"{key}: {value}") def demo_error_handling(): """演示错误处理""" print("\n=== 错误处理演示 ===") # 模拟连接错误的客户端 client = OllamaClient( base_url="http://localhost:99999", # 无效端口 max_retries=2, retry_delay=0.5 ) try: response = client.generate( model="llama2", prompt="测试请求", temperature=0.7 ) print(f"意外成功: {response}") except Exception as e: print(f"预期的错误: {e}") # 显示错误统计 stats = client.get_stats() print(f"错误统计: {stats}") def demo_context_manager(): """演示上下文管理器使用""" print("\n=== 上下文管理器演示 ===") with OllamaClient(max_retries=2) as client: if client.health_check(): print("服务健康检查通过") # 使用客户端进行请求 models = client.list_models() print(f"可用模型数量: {len(models.get('models', []))}") else: print("服务健康检查失败") if __name__ == "__main__": main() demo_error_handling() demo_context_manager()