retriver

Retriver

retriever

Source code in src/racp/retriver.py
class Retriver():
    """Retriever that indexes paper chunks in FAISS and returns deduplicated hits.

    Documents are split into chunks, embedded with a HuggingFace model through
    an on-disk embedding cache, and stored in a FAISS index (``self.db``).
    """

    def __init__(self, config=None, database=None) -> None:
        """Initialize retriever using config and database.

        Args:
            config (Config): configuration for the retriever. The attributes
                ``chunk_size``, ``chunk_overlap``, ``device``,
                ``normalize_embeddings`` and ``model_name`` are read from it.
            database (list): a list of objects exposing ``to_Document()`` to
                build the retriever from.

        Raises:
            ValueError: if ``config`` or ``database`` is not given.
        """
        # Validate arguments first: loading the embedding model is expensive
        # and should not happen when construction is bound to fail anyway.
        if config is None:
            raise ValueError('Please specify config')
        if database is None:
            raise ValueError('Please specify database')
        self.text_splitter = CharacterTextSplitter(chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap)
        self.build_embedding_model(config)
        self.build_retriver_from_database(database)

    def build_embedding_model(self, config):
        """Initialize the HuggingFaceEmbeddings model and store it on ``self.hf``.

        Args:
            config (Config): configuration for the retriever; ``device``,
                ``normalize_embeddings`` and ``model_name`` are read from it.
        """
        model_kwargs = {'device': config.device}
        encode_kwargs = {'normalize_embeddings': config.normalize_embeddings}
        self.hf = HuggingFaceEmbeddings(model_name=config.model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

    def build_retriver_from_database(self, database):
        """Build the FAISS vector store (``self.db``) from the database.

        Each entry is converted with ``to_Document()``, split into chunks with
        ``self.text_splitter`` and embedded with ``self.hf`` through an on-disk
        cache under ``./cache/``.

        Args:
            database (list): a list of objects exposing ``to_Document()`` to
                build the retriever from.
        """
        import re
        from time import time

        data = [item.to_Document() for item in database]
        print(f'Loaded {len(data)} documents using database ')
        documents = self.text_splitter.split_documents(data)
        # Diagnostic: number of distinct sources vs. number of chunks.
        ids = {doc.metadata['source'] for doc in documents}
        print(len(ids), len(documents))

        t0 = time()
        store = LocalFileStore("./cache/")
        # Per the LangChain docs, CacheBackedEmbeddings keys cached vectors by
        # namespace plus a hash of the text. The namespace must therefore
        # identify the embedding model. Otherwise vectors computed by a
        # previously configured model are silently reused after switching
        # models. Characters outside [A-Za-z0-9_.-] are replaced so the
        # namespace stays a valid flat LocalFileStore key.
        model_name = getattr(self.hf, 'model_name', None) or 'test'
        namespace = re.sub(r'[^A-Za-z0-9_.-]', '_', str(model_name))
        cached_embedder = CacheBackedEmbeddings.from_bytes_store(
            self.hf, store, namespace=namespace
        )
        self.db = FAISS.from_documents(documents, cached_embedder)
        t1 = time()
        print("loading time ", t1 - t0)

    def retrival(self, query, k=10):
        """Perform retrieval, keeping at most one hit per arxiv id.

        Args:
            query (str): the query to search for in the retriever.
            k (int): base number of results. The index is queried for
                ``2 * k`` chunks, because several chunks of the same paper can
                match. Up to ``2 * k`` unique papers may be returned.

        Returns:
            list: dicts with keys ``'Papername'``, ``'arxiv_id'`` and
            ``'relevance'``. They follow the order returned by the index and
            contain only hits with a positive relevance score. Only the first
            hit per ``arxiv_id`` is kept.
        """
        docs = self.db.similarity_search_with_relevance_scores(query, k=k * 2)
        seen = set()
        unique_result = []
        for doc, score in docs:
            # Check the score first so metadata is only read for kept hits.
            if score > 0 and doc.metadata['source'] not in seen:
                arxiv_id = doc.metadata['source']
                seen.add(arxiv_id)
                unique_result.append({'Papername': doc.metadata['title'], 'arxiv_id': arxiv_id, 'relevance': score})
        print(unique_result)
        return unique_result

__init__(config=None, database=None)

Initialize the retriever from a configuration and a document database.

Parameters:
  • config (Config, default: None ) –

    configuration for the retriever.

  • database (list, default: None ) –

    a list of Document objects to build the retriever from.

Source code in src/racp/retriver.py
def __init__(self, config=None, database=None) -> None:
    """Initialize retriever using config and database.

    Args:
        config (Config): configuration for the retriever. The attributes
            ``chunk_size`` and ``chunk_overlap`` are read here; the rest is
            consumed by ``build_embedding_model``.
        database (list): a list of objects to build the retriever from.

    Raises:
        ValueError: if ``config`` or ``database`` is not given.
    """
    # Validate arguments first: building the embedding model is expensive and
    # should not happen when construction is bound to fail anyway.
    if config is None:
        raise ValueError('Please specify config')
    if database is None:
        raise ValueError('Please specify database')
    self.text_splitter = CharacterTextSplitter(chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap)
    self.build_embedding_model(config)
    self.build_retriver_from_database(database)

build_embedding_model(config)

Initialize the HuggingFaceEmbeddings model used to embed documents.

Parameters:
  • config (Config) –

    configuration for the retriever.

Source code in src/racp/retriver.py
def build_embedding_model(self, config):
    """Create the HuggingFace embedding model and store it on ``self.hf``.

    Args:
        config (Config): retriever configuration supplying ``device``,
            ``normalize_embeddings`` and ``model_name``.
    """
    device_options = dict(device=config.device)
    encoding_options = dict(normalize_embeddings=config.normalize_embeddings)
    self.hf = HuggingFaceEmbeddings(
        model_name=config.model_name,
        model_kwargs=device_options,
        encode_kwargs=encoding_options,
    )

build_retriver_from_database(database)

Build the retriever's vector index from the database.

Parameters:
  • database (list) –

    a list of Document objects to build the retriever from.

Source code in src/racp/retriver.py
def build_retriver_from_database(self, database):
    """Build the FAISS vector store (``self.db``) from the database.

    Each entry is converted with ``to_Document()``, split into chunks with
    ``self.text_splitter`` and embedded with ``self.hf`` through an on-disk
    cache under ``./cache/``.

    Args:
        database (list): a list of objects exposing ``to_Document()`` to build
            the retriever from.
    """
    import re
    from time import time

    data = [item.to_Document() for item in database]
    print(f'Loaded {len(data)} documents using database ')
    documents = self.text_splitter.split_documents(data)
    # Diagnostic: number of distinct sources vs. number of chunks.
    ids = {doc.metadata['source'] for doc in documents}
    print(len(ids), len(documents))

    t0 = time()
    store = LocalFileStore("./cache/")
    # Per the LangChain docs, CacheBackedEmbeddings keys cached vectors by
    # namespace plus a hash of the text. The namespace must therefore identify
    # the embedding model. Otherwise vectors computed by a previously
    # configured model are silently reused after switching models. Characters
    # outside [A-Za-z0-9_.-] are replaced so the namespace stays a valid flat
    # LocalFileStore key.
    model_name = getattr(self.hf, 'model_name', None) or 'test'
    namespace = re.sub(r'[^A-Za-z0-9_.-]', '_', str(model_name))
    cached_embedder = CacheBackedEmbeddings.from_bytes_store(
        self.hf, store, namespace=namespace
    )
    self.db = FAISS.from_documents(documents, cached_embedder)
    t1 = time()
    print("loading time ", t1 - t0)

retrival(query, k=10)

Perform a similarity search for the query and return hits deduplicated by arxiv id.

Parameters:
  • query (str) –

    the query to search for in the retriever.

  • k (int, default: 10 ) –

    base number of results. The index is queried for 2 * k chunks, so up to 2 * k unique papers may be returned.

Returns:
  • list

    a list of dictionaries with keys 'Papername', 'arxiv_id' and 'relevance', one per unique arxiv id with a positive relevance score.

Source code in src/racp/retriver.py
def retrival(self, query, k=10):
    """Perform retrieval, keeping at most one hit per arxiv id.

    Args:
        query (str): the query to search for in the retriever.
        k (int): base number of results. The index is queried for ``2 * k``
            chunks, because several chunks of the same paper can match. Up
            to ``2 * k`` unique papers may be returned.

    Returns:
        list: dicts with keys ``'Papername'``, ``'arxiv_id'`` and
        ``'relevance'``. They follow the order returned by the index and
        contain only hits with a positive relevance score. Only the first hit
        per ``arxiv_id`` is kept.
    """
    docs = self.db.similarity_search_with_relevance_scores(query, k=k * 2)
    seen = set()
    unique_result = []
    for doc, score in docs:
        # Check the score first so metadata is only read for kept hits.
        if score > 0 and doc.metadata['source'] not in seen:
            arxiv_id = doc.metadata['source']
            seen.add(arxiv_id)
            unique_result.append({'Papername': doc.metadata['title'], 'arxiv_id': arxiv_id, 'relevance': score})
    print(unique_result)
    return unique_result