# unsloth/005ollamaApi.py
# NOTE(review): removed non-Python web-viewer residue ("226 lines", "Raw Permalink
# Blame History", Unicode-warning banner) that made this file unrunnable.
import requests
import json
import time
from typing import Dict, Any, Iterator
class OllamaClient:
    """Minimal HTTP client for a local Ollama server's REST API.

    Wraps the ``/api/tags``, ``/api/generate``, ``/api/chat`` and ``/api/pull``
    endpoints, with optional streaming (NDJSON) handling.
    """

    def __init__(self, base_url: str = "http://localhost:11434", timeout: float = 120.0):
        """
        Initialize the Ollama client.

        Args:
            base_url: Base URL of the Ollama service (defaults to the local
                port 11434 that ``ollama serve`` listens on).
            timeout: Per-request timeout in seconds. For streaming requests
                this bounds the connect time and the gap between chunks.
                Without it, a dead server would hang every call forever.
        """
        self.base_url = base_url
        self.timeout = timeout
        # One Session reuses the underlying TCP connection pool across calls.
        self.session = requests.Session()

    def list_models(self) -> Dict[str, Any]:
        """Return the installed-model listing from ``/api/tags``.

        Returns:
            The decoded JSON payload (contains a ``"models"`` list), or an
            empty dict if the request fails (error is printed, not raised).
        """
        try:
            response = self.session.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"获取模型列表失败: {e}")
            return {}

    def generate(self, model: str, prompt: str, stream: bool = False, **kwargs) -> Any:
        """
        Generate text via ``/api/generate``.

        Args:
            model: Model name (e.g. ``"llama3:latest"``).
            prompt: Input prompt.
            stream: If True, return an iterator of NDJSON chunks; otherwise
                return the single decoded JSON response.
            **kwargs: Extra top-level API fields (e.g. ``options={...}``,
                ``system=...``). Note that sampling parameters such as
                ``temperature`` must be nested under ``options`` per the
                Ollama API; top-level unknown fields are ignored.

        Returns:
            Decoded JSON dict, an iterator of chunk dicts when streaming,
            or None on request failure (error is printed, not raised).
        """
        data = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            **kwargs,
        }
        try:
            response = self.session.post(
                f"{self.base_url}/api/generate",
                json=data,
                stream=stream,
                timeout=self.timeout,
            )
            # Surface HTTP-level errors (404 unknown model, 500, ...) early.
            response.raise_for_status()
            if stream:
                return self._handle_stream_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"生成请求失败: {e}")
            return None

    def chat(self, model: str, messages: list, stream: bool = False, **kwargs) -> Any:
        """
        Multi-turn chat via ``/api/chat``.

        Args:
            model: Model name.
            messages: Message list in the form
                ``[{"role": "user", "content": "..."}]``.
            stream: If True, return an iterator of NDJSON chunks.
            **kwargs: Extra top-level API fields (see :meth:`generate`).

        Returns:
            Decoded JSON dict, an iterator of chunk dicts when streaming,
            or None on request failure (error is printed, not raised).
        """
        data = {
            "model": model,
            "messages": messages,
            "stream": stream,
            **kwargs,
        }
        try:
            response = self.session.post(
                f"{self.base_url}/api/chat",
                json=data,
                stream=stream,
                timeout=self.timeout,
            )
            response.raise_for_status()
            if stream:
                return self._handle_stream_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"对话请求失败: {e}")
            return None

    def _handle_stream_response(self, response) -> Iterator[Dict[str, Any]]:
        """Yield decoded dicts from an NDJSON streaming response.

        Blank keep-alive lines and undecodable lines are skipped silently.
        """
        for line in response.iter_lines():
            if line:
                try:
                    yield json.loads(line.decode("utf-8"))
                except json.JSONDecodeError:
                    continue

    def pull_model(self, model_name: str) -> bool:
        """Pull (download) a model via ``/api/pull``, printing progress.

        Args:
            model_name: Model to pull, e.g. ``"llama3:latest"``.

        Returns:
            True once a chunk with ``done: true`` arrives, False on request
            failure or if the stream ends without a completion marker.
        """
        data = {"name": model_name}
        try:
            response = self.session.post(
                f"{self.base_url}/api/pull",
                json=data,
                stream=True,
                timeout=self.timeout,
            )
            response.raise_for_status()
            print(f"正在拉取模型 {model_name}...")
            for chunk in self._handle_stream_response(response):
                if "status" in chunk:
                    print(f"状态: {chunk['status']}")
                if chunk.get("done", False):
                    print("模型拉取完成!")
                    return True
            # Stream ended without a done marker: report failure explicitly
            # (the original fell through and implicitly returned None).
            return False
        except requests.exceptions.RequestException as e:
            print(f"拉取模型失败: {e}")
            return False

    def is_model_available(self, model_name: str) -> bool:
        """Return True if any installed model name starts with *model_name*.

        Prefix matching lets ``"llama3"`` match ``"llama3:latest"``; note it
        would also match e.g. ``"llama3-foo"`` — pass a full ``name:tag`` for
        an exact check.
        """
        models = self.list_models()
        if "models" in models:
            return any(
                model["name"].startswith(model_name) for model in models["models"]
            )
        return False
# Usage example
def main():
    """Demo driver: connect to a local Ollama server and exercise the client.

    Shows plain generation, streaming generation, single-turn chat and a
    two-round conversation. Exits early if the server is unreachable or the
    model cannot be pulled.

    NOTE: the Ollama REST API only honors sampling parameters nested under
    the ``"options"`` field, and its token limit is called ``num_predict``
    (not ``max_tokens``). The original passed ``temperature=...`` /
    ``max_tokens=...`` as top-level JSON fields, which the server silently
    ignored — they are now nested correctly.
    """
    # Create the client (defaults to http://localhost:11434).
    client = OllamaClient()

    # Verify the service is reachable before doing anything else.
    try:
        models = client.list_models()
        print("Ollama服务连接成功!")
        print(f"可用模型: {[m['name'] for m in models.get('models', [])]}")
    except Exception as e:
        print(f"无法连接到Ollama服务: {e}")
        return

    model_name = "Qwen3-8B:latest"  # adjust to whichever model you have

    # Pull the model if it is not installed yet.
    if not client.is_model_available(model_name):
        print(f"模型 {model_name} 不存在,正在拉取...")
        if not client.pull_model(model_name):
            print("模型拉取失败")
            return

    # Example 1: simple (non-streaming) text generation.
    print("\n=== 简单文本生成 ===")
    response = client.generate(
        model=model_name,
        prompt="请用中文简单介绍一下人工智能。",
        options={"temperature": 0.7, "num_predict": 200},
    )
    if response:
        print(f"回答: {response.get('response', '')}")

    # Example 2: streaming output, printed token-by-token as chunks arrive.
    print("\n=== 流式输出 ===")
    stream_response = client.generate(
        model=model_name,
        prompt="请讲一个简短的故事。",
        stream=True,
        options={"temperature": 0.8, "num_predict": 200},
    )
    print("流式回答: ", end="")
    for chunk in stream_response:
        if "response" in chunk:
            print(chunk["response"], end="", flush=True)
        if chunk.get("done", False):
            print("\n")
            break

    # Example 3: single-turn chat.
    print("\n=== 对话模式 ===")
    messages = [
        {"role": "user", "content": "你好,请介绍一下自己。"},
    ]
    chat_response = client.chat(
        model=model_name,
        messages=messages,
        options={"temperature": 0.7},
    )
    if chat_response:
        print(f"AI回答: {chat_response.get('message', {}).get('content', '')}")

    # Example 4: multi-turn conversation — the assistant's reply is appended
    # to the history so the second round has context.
    print("\n=== 多轮对话 ===")
    conversation = [
        {"role": "user", "content": "什么是机器学习?"}
    ]
    # Round 1
    response1 = client.chat(model=model_name, messages=conversation)
    if response1 and "message" in response1:
        ai_response = response1["message"]["content"]
        print(f"AI: {ai_response}")
        conversation.append({"role": "assistant", "content": ai_response})
        # Round 2 (only if round 1 succeeded, so the history is coherent)
        conversation.append({"role": "user", "content": "能举个具体例子吗?"})
        response2 = client.chat(model=model_name, messages=conversation)
        if response2 and "message" in response2:
            print(f"AI: {response2['message']['content']}")


if __name__ == "__main__":
    main()