-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAdvancedEmbedder.py
More file actions
68 lines (50 loc) · 2.58 KB
/
AdvancedEmbedder.py
File metadata and controls
68 lines (50 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
class AdvancedEmbedder:
    """Score semantic relevance between short phrases and article abstracts.

    Wraps a SentenceTransformer model and exposes cosine-similarity scoring.
    The most recently embedded abstract is cached on the instance, so
    repeated relevance queries against the same abstract text do not
    re-encode it.
    """

    def __init__(self, model_name):
        """Load the sentence-embedding model.

        Args:
            model_name (str): Name or path of the SentenceTransformer model
                to load (e.g. "all-MiniLM-L6-v2").
        """
        self.model = SentenceTransformer(model_name)
        # article_id is kept for backward compatibility with callers that
        # set/read it; the cache itself is keyed on the abstract text.
        self.article_id = None
        # Embedding of the most recently encoded abstract (shape (1, dim)).
        self.abstract_embedding = None
        # Raw text of the cached abstract; None until abstract_embed runs.
        self._cached_abstract = None

    def abstract_embed(self, abstract):
        """Encode *abstract* and cache its embedding on the instance.

        Args:
            abstract (str): Abstract text to embed.
        """
        self.abstract_embedding = self.model.encode([abstract])
        self._cached_abstract = abstract

    def calculate_relevance(self, phrase: str, abstract: str) -> float:
        """Compute cosine similarity between a phrase and an abstract.

        The abstract is re-encoded only when its text differs from the
        previously cached abstract; the phrase is encoded on every call.

        Args:
            phrase (str): The phrase whose similarity to the abstract will
                be calculated.
            abstract (str): The abstract text to compare against.

        Returns:
            float: Cosine similarity score in [-1, 1] between the phrase
            embedding and the abstract embedding.
        """
        if abstract != self._cached_abstract:
            self.abstract_embed(abstract)
        phrase_embedding = self.model.encode([phrase])
        # encode([...]) returns a 2-D (1, dim) array, so cosine_similarity
        # yields a (1, 1) matrix; [0][0] extracts the scalar score.
        similarity = cosine_similarity(phrase_embedding, self.abstract_embedding)[0][0]
        return float(similarity)