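"""Minimal Python client for a local Ollama server's HTTP API.

Wraps the /api/tags, /api/generate, /api/chat, and /api/pull endpoints,
with optional streaming of Ollama's newline-delimited JSON responses.
"""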
import requests
import json
from typing import Dict, Any, Iterator


class OllamaClient:
    def __init__(self, base_url: str = "http://localhost:11434"):
        """
        Initialize the Ollama client.

        Args:
            base_url: Base URL of the Ollama service; defaults to
                Ollama's standard local port, 11434.
        """
        self.base_url = base_url
        self.session = requests.Session()

    def list_models(self) -> Dict[str, Any]:
        """Return the list of locally installed models."""
        try:
            response = self.session.get(f"{self.base_url}/api/tags")
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed to list models: {e}")
            return {}

    def generate(self, model: str, prompt: str, stream: bool = False, **options) -> Any:
        """
        Generate text from a prompt.

        Args:
            model: Model name.
            prompt: Input prompt.
            stream: Whether to stream the output.
            **options: Sampling parameters such as temperature, top_p,
                or num_predict.
        """
        data = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
        }
        # Ollama expects sampling parameters nested under an "options" key,
        # not at the top level of the request body.
        if options:
            data["options"] = options

        try:
            response = self.session.post(
                f"{self.base_url}/api/generate",
                json=data,
                stream=stream
            )
            # Raise on HTTP-level errors
            response.raise_for_status()

            if stream:
                return self._handle_stream_response(response)
            else:
                return response.json()

        except requests.exceptions.RequestException as e:
            print(f"Generate request failed: {e}")
            return None

    def chat(self, model: str, messages: list, stream: bool = False, **options) -> Any:
        """
        Chat with the model.

        Args:
            model: Model name.
            messages: Message list in the form [{"role": "user", "content": "..."}].
            stream: Whether to stream the output.
            **options: Sampling parameters such as temperature or top_p.
        """
        data = {
            "model": model,
            "messages": messages,
            "stream": stream,
        }
        # As in generate(), sampling parameters go under "options".
        if options:
            data["options"] = options

        try:
            response = self.session.post(
                f"{self.base_url}/api/chat",
                json=data,
                stream=stream
            )
            response.raise_for_status()

            if stream:
                return self._handle_stream_response(response)
            else:
                return response.json()

        except requests.exceptions.RequestException as e:
            print(f"Chat request failed: {e}")
            return None

    def _handle_stream_response(self, response) -> Iterator[Dict[str, Any]]:
        """Yield parsed objects from a streaming (newline-delimited JSON) response."""
        for line in response.iter_lines():
            if line:
                try:
                    yield json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    # Skip lines that are not valid JSON
                    continue
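
    # Note: /api/pull streams NDJSON progress updates (e.g. "pulling manifest",
    # per-layer download progress) while the model downloads.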
    def pull_model(self, model_name: str) -> bool:
        """Pull a model; returns True on success."""
        data = {"name": model_name}

        try:
            response = self.session.post(
                f"{self.base_url}/api/pull",
                json=data,
                stream=True
            )
            response.raise_for_status()

            print(f"Pulling model {model_name}...")
            for chunk in self._handle_stream_response(response):
                if "status" in chunk:
                    print(f"Status: {chunk['status']}")
                # The pull stream finishes with {"status": "success"} rather
                # than the "done" flag used by /api/generate.
                if chunk.get("status") == "success" or chunk.get("done", False):
                    print("Model pull complete!")
                    return True
            return False

        except requests.exceptions.RequestException as e:
            print(f"Failed to pull model: {e}")
            return False

    def is_model_available(self, model_name: str) -> bool:
        """Check whether a model is installed locally."""
        models = self.list_models()
        if "models" in models:
            # Installed names carry a tag (e.g. "name:latest"), so match on the prefix.
            return any(model["name"].startswith(model_name) for model in models["models"])
        return False
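
# For reference, abridged response shapes the examples below rely on:
#   /api/generate (non-stream): {"response": "<text>", "done": true, ...}
#   /api/chat (non-stream): {"message": {"role": "assistant", "content": "<text>"}, "done": true, ...}
# Streaming calls emit one such JSON object per line; the last has "done": true.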


# Usage example
def main():
    # Create the client
    client = OllamaClient()

    # Check that the service is reachable. list_models() swallows request
    # errors and returns {}, so test for an empty result rather than an exception.
    models = client.list_models()
    if not models:
        print("Could not connect to the Ollama service.")
        return
    print("Connected to the Ollama service!")
    print(f"Available models: {[m['name'] for m in models.get('models', [])]}")

    model_name = "Qwen3-8B:latest"  # adjust to match your model

    # Pull the model if it is not installed yet
    if not client.is_model_available(model_name):
        print(f"Model {model_name} is not installed; pulling...")
        if not client.pull_model(model_name):
            print("Failed to pull the model")
            return

    # Example 1: simple text generation
    print("\n=== Simple text generation ===")
    response = client.generate(
        model=model_name,
        prompt="Please give a brief introduction to artificial intelligence.",
        temperature=0.7,
        num_predict=200  # Ollama's token-limit option (it does not define max_tokens)
    )

    if response:
        print(f"Answer: {response.get('response', '')}")

    # Example 2: streaming output
    print("\n=== Streaming output ===")
    stream_response = client.generate(
        model=model_name,
        prompt="Please tell a short story.",
        stream=True,
        temperature=0.8,
        num_predict=200
    )

    # generate() returns None on error, so guard before iterating
    if stream_response:
        print("Streamed answer: ", end="")
        for chunk in stream_response:
            if "response" in chunk:
                print(chunk["response"], end="", flush=True)
            if chunk.get("done", False):
                print("\n")
                break

    # Example 3: chat mode
    print("\n=== Chat mode ===")
    messages = [
        {"role": "user", "content": "Hello, please introduce yourself."},
    ]

    chat_response = client.chat(
        model=model_name,
        messages=messages,
        temperature=0.7
    )

    if chat_response:
        print(f"AI answer: {chat_response.get('message', {}).get('content', '')}")

    # Example 4: multi-turn conversation
    print("\n=== Multi-turn conversation ===")
    conversation = [
        {"role": "user", "content": "What is machine learning?"}
    ]

    # First turn
    response1 = client.chat(model=model_name, messages=conversation)
    if response1 and "message" in response1:
        ai_response = response1["message"]["content"]
        print(f"AI: {ai_response}")
        conversation.append({"role": "assistant", "content": ai_response})

    # Second turn: include the assistant's reply so the model has context
    conversation.append({"role": "user", "content": "Can you give a concrete example?"})
    response2 = client.chat(model=model_name, messages=conversation)
    if response2 and "message" in response2:
        print(f"AI: {response2['message']['content']}")
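
    # Extra sketch (not in the original example): /api/chat can stream as well,
    # yielding partial {"message": {"content": ...}} chunks until "done" is true.
    print("\n=== Streaming chat (sketch) ===")
    stream_chat = client.chat(model=model_name, messages=conversation, stream=True)
    if stream_chat:
        for chunk in stream_chat:
            print(chunk.get("message", {}).get("content", ""), end="", flush=True)
            if chunk.get("done", False):
                print()
                break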


if __name__ == "__main__":
    main()