Commit 43401c8

Merge pull request #48 from VinciGit00/llama_new_models
Llama new models
2 parents 958b2a4 + 992e7f8

File tree

7 files changed (+31, -19)


scrapegraphai/graphs/abstract_graph.py (+15, -6)
@@ -6,6 +6,7 @@
 from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
 from ..helpers import models_tokens

+
 class AbstractGraph(ABC):
     """
     Abstract class representing a generic graph-based tool.
@@ -19,7 +20,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"])
-        self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"])
+        self.embedder_model = None if "embeddings" not in config else self._create_llm(
+            config["embeddings"])
         self.graph = self._create_graph()

     def _create_llm(self, llm_config: dict):
@@ -39,7 +41,7 @@ def _create_llm(self, llm_config: dict):
             except KeyError:
                 raise ValueError("Model not supported")
             return OpenAI(llm_params)
-
+
         elif "azure" in llm_params["model"]:
             # take the model after the last dash
             llm_params["model"] = llm_params["model"].split("/")[-1]
@@ -48,23 +50,30 @@ def _create_llm(self, llm_config: dict):
             except KeyError:
                 raise ValueError("Model not supported")
             return AzureOpenAI(llm_params)
-
+
         elif "gemini" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["gemini"][llm_params["model"]]
             except KeyError:
                 raise ValueError("Model not supported")
             return Gemini(llm_params)
-
+
         elif "ollama" in llm_params["model"]:
-            # take the model after the last dash
+            """
+            Available models:
+            - llama2
+            - mistral
+            - codellama
+            - dolphin-mixtral
+            - mistral-openorca
+            """
             llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["ollama"][llm_params["model"]]
             except KeyError:
                 raise ValueError("Model not supported")
             return Ollama(llm_params)
-
+
         else:
             raise ValueError("Model not supported")

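For context, here is how the newly supported Ollama entries are reached from user code. This is a minimal sketch, assuming SmartScraperGraph is exported from scrapegraphai.graphs and follows the AbstractGraph constructor shown above; the prompt, URL, and config values are illustrative, not part of this commit.

from scrapegraphai.graphs import SmartScraperGraph

graph_config = {
    "llm": {
        # _create_llm() keeps only the part after the last "/", so the
        # remainder must match a key in models_tokens["ollama"]
        "model": "ollama/mistral-openorca",
    },
    # no "embeddings" key here, so embedder_model stays None
}

smart_scraper = SmartScraperGraph(
    prompt="List all the articles on this page",  # illustrative prompt
    source="https://example.com",                 # illustrative URL
    config=graph_config,
)
result = smart_scraper.run()  # assumes the public run() entry point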
scrapegraphai/graphs/smart_scraper_graph.py (-1)
@@ -1,7 +1,6 @@
 """
 Module for creating the smart scraper
 """
-from ..models import OpenAI, Gemini
 from .base_graph import BaseGraph
 from ..nodes import (
     FetchNode,

scrapegraphai/graphs/speech_graph.py (+2, -2)
@@ -2,7 +2,7 @@
 Module for converting text to speech
 """
 from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
-from ..models import OpenAI, Gemini, OpenAITextToSpeech
+from ..models import OpenAITextToSpeech
 from .base_graph import BaseGraph
 from ..nodes import (
     FetchNode,
@@ -27,7 +27,7 @@ def __init__(self, prompt: str, source: str, config: dict):
         super().__init__(prompt, config, source)

         self.input_key = "url" if source.startswith("http") else "local_dir"
-
+
     def _create_graph(self):
         """
         Creates the graph of nodes representing the workflow for web scraping and summarization.

scrapegraphai/helpers/models_tokens.py (+4, -2)
@@ -21,9 +21,11 @@
         "gemini-pro": 128000,
     },

-    "ollama":{
+    "ollama": {
         "llama2": 4096,
         "mistral": 8192,
+        "codellama": 16000,
+        "dolphin-mixtral": 32000,
+        "mistral-openorca": 32000,
     }
-
 }

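This table feeds the KeyError-guarded lookups in abstract_graph.py above. As a standalone sketch (model names and token counts taken from this diff):

from scrapegraphai.helpers import models_tokens

model_name = "ollama/dolphin-mixtral".split("/")[-1]  # -> "dolphin-mixtral"
try:
    context_window = models_tokens["ollama"][model_name]  # -> 32000
except KeyError:
    raise ValueError("Model not supported")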
scrapegraphai/nodes/fetch_node.py (+2, -2)
@@ -67,8 +67,8 @@ def execute(self, state):
        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

-       source = input_data[0]
-
+       source = input_data[0]
+
        # if it is a local directory
        if not source.startswith("http"):
            document = [Document(page_content=source, metadata={

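The hunk above sits just before the local-vs-remote branch: anything that does not start with "http" is treated as already-fetched text and wrapped in a Document directly. A minimal sketch of that pattern; the import path and metadata key are assumptions, not shown in this hunk.

from langchain_core.documents import Document  # assumed import path

def load_source(source: str) -> list[Document]:
    if not source.startswith("http"):
        # local text: wrap it directly, no network fetch needed
        return [Document(page_content=source, metadata={"source": "local_dir"})]
    raise NotImplementedError("URL fetching elided in this sketch")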
scrapegraphai/nodes/rag_node.py (+7, -5)
@@ -8,8 +8,8 @@
 from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
 from langchain_community.document_transformers import EmbeddingsRedundantFilter
 from langchain_community.vectorstores import FAISS
-from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
 from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
+from ..models import OpenAI, Ollama, AzureOpenAI
 from langchain_community.embeddings import OllamaEmbeddings
 from .base_node import BaseNode

@@ -86,16 +86,18 @@ def execute(self, state):
         embedding_model = self.embedder_model if self.embedder_model else self.llm_model

         if isinstance(embedding_model, OpenAI):
-            embeddings = OpenAIEmbeddings(api_key=embedding_model.openai_api_key)
+            embeddings = OpenAIEmbeddings(
+                api_key=embedding_model.openai_api_key)
         elif isinstance(embedding_model, AzureOpenAI):
             embeddings = AzureOpenAIEmbeddings()
         elif isinstance(embedding_model, Ollama):
             embeddings = OllamaEmbeddings()
         else:
             raise ValueError("Embedding Model missing or not supported")
-
-        retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
-
+
+        retriever = FAISS.from_documents(
+            chunked_docs, embeddings).as_retriever()
+
         redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
         # similarity_threshold could be set, now k=20
         relevant_filter = EmbeddingsFilter(embeddings=embeddings)

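For reference, the retrieval pipeline this node builds can be reproduced standalone. A sketch under stated assumptions: the imports mirror the diff's own import block, the documents and API key are placeholders, and the final ContextualCompressionRetriever wiring follows the standard LangChain pattern (this hunk ends before that step, so the exact wiring in rag_node.py is not shown).

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline, EmbeddingsFilter)
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(api_key="sk-...")  # placeholder key
chunked_docs = [Document(page_content="chunk one"),
                Document(page_content="chunk two")]  # placeholder chunks

# index the chunks and expose them as a plain retriever
retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()

# drop near-duplicate chunks, then keep only query-relevant ones (k=20 by default)
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings)
pipeline = DocumentCompressorPipeline(
    transformers=[redundant_filter, relevant_filter])
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline, base_retriever=retriever)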
scrapegraphai/nodes/search_internet_node.py (+1, -1)
@@ -94,7 +94,7 @@ def execute(self, state):
         # Execute the chain to get the search query
         search_answer = search_prompt | self.llm_model | output_parser
         search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
-
+
         print(f"Search Query: {search_query}")
         # TODO: handle multiple URLs
         answer = search_on_web(query=search_query, max_results=1)[0]

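The changed line sits inside LangChain's pipe-composition pattern: prompt | model | parser yields a runnable whose invoke() returns the parsed output. A self-contained sketch of that pattern; the prompt template, parser choice, and model are assumptions, since only the composed chain is visible in this hunk.

from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

output_parser = CommaSeparatedListOutputParser()
search_prompt = PromptTemplate(
    template="Rewrite this request as a web search query: {user_prompt}",
    input_variables=["user_prompt"],
)
llm = ChatOpenAI(api_key="sk-...")  # placeholder key

search_answer = search_prompt | llm | output_parser
# invoke() returns the parsed list; [0] takes the first candidate query,
# matching the indexing in search_internet_node.py
search_query = search_answer.invoke({"user_prompt": "best pizza in Naples"})[0]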