diff --git a/free_ask_internet.py b/free_ask_internet.py index 62d9f29..875a2ea 100644 --- a/free_ask_internet.py +++ b/free_ask_internet.py @@ -1,24 +1,21 @@ # -*- coding: utf-8 -*- - -import json -import os from pprint import pprint import requests import trafilatura from trafilatura import bare_extraction from concurrent.futures import ThreadPoolExecutor import concurrent -import requests -import openai -import time -from datetime import datetime +import openai from urllib.parse import urlparse import tldextract -import platform import urllib.parse +from typing import List, Dict, Any, Optional, Generator -def extract_url_content(url): +UrlStr = str +Query = str + +def extract_url_content(url: UrlStr) -> Dict[str, Any]: downloaded = trafilatura.fetch_url(url) content = trafilatura.extract(downloaded) @@ -26,8 +23,7 @@ def extract_url_content(url): - -def search_web_ref(query:str, debug=False): +def search_web_ref(query: Query, debug: Optional[bool] = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: content_list = [] @@ -56,7 +52,7 @@ def search_web_ref(query:str, debug=False): if url: url_parsed = urlparse(url) domain = url_parsed.netloc - icon_url = url_parsed.scheme + '://' + url_parsed.netloc + '/favicon.ico' + icon_url = f'{url_parsed.scheme}://{domain}/favicon.ico' site_name = tldextract.extract(url).domain conv_links.append({ @@ -94,13 +90,15 @@ def search_web_ref(query:str, debug=False): print("URL: {}".format(url)) print("=================") - return conv_links,content_list + + return conv_links, content_list except Exception as ex: raise ex -def gen_prompt(question,content_list, lang="zh-CN", context_length_limit=11000,debug=False): - +def gen_prompt(question: Query, lang: Optional[str] = 'zh-CN', content_list: List[Dict[str, Any]], + context_length_limit: Optional[int] = 11000, debug: Optional[bool] = False) -> str: + limit_len = (context_length_limit - 2000) if len(question) > limit_len: question = question[0:limit_len] @@ -168,7 +166,7 @@ def gen_prompt(question,content_list, lang="zh-CN", context_length_limit=11000,d return prompts -def chat(prompt, model:str,llm_auth_token:str,llm_base_url:str,using_custom_llm=False,stream=True, debug=False): +def chat(prompt, model:str, llm_auth_token:str, llm_base_url:str, using_custom_llm=False, stream=True, debug=False): openai.base_url = "http://127.0.0.1:3040/v1/" if model == "gpt3.5": @@ -191,7 +189,6 @@ def chat(prompt, model:str,llm_auth_token:str,llm_base_url:str,using_custom_llm= openai.base_url = llm_base_url openai.api_key = "CUSTOM" - total_content = "" for chunk in openai.chat.completions.create( model=model, @@ -200,21 +197,18 @@ def chat(prompt, model:str,llm_auth_token:str,llm_base_url:str,using_custom_llm= "content": prompt }], stream=True, - max_tokens=1024,temperature=0.2 + max_tokens=1024, + temperature=0.2 ): stream_resp = chunk.dict() token = stream_resp["choices"][0]["delta"].get("content", "") if token: - total_content += token yield token if debug: print(total_content) - - - -def ask_internet(query:str, debug=False): +def ask_internet(query: Query, debug: Optional[bool] = False) -> Generator[str, None, None]: content_list = search_web_ref(query,debug=debug) if debug: