Last active
November 23, 2023 08:31
-
-
Save nreimers/4fa0fdac578de4288399929a23fa9f3e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: lexical search with Elasticsearch, then re-ranking of the hits with Cohere ReRank.
#
# 1) Install dependencies: pip install cohere datasets elasticsearch==8.6.2
# 2) Start a local Elasticsearch server (security disabled so the plain-HTTP connection below works):
#    docker run -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" elasticsearch:8.6.2
# 3) Get your Cohere API key and paste it below
from elasticsearch import Elasticsearch, helpers
import cohere
from datasets import load_dataset

# Get your cohere API key on: www.cohere.com
co = cohere.Client("<<YOUR_COHERE_API_KEY>>>")

# Connect to elastic
es = Elasticsearch("http://localhost:9200")

# If the ES index does not exist yet, load simple English Wikipedia dataset and index it
index = "wikipedia"
if not es.indices.exists(index=index):
    print("Load dataset")
    # Stream the dataset so it is never fully materialized in memory.
    data = load_dataset("Cohere/wikipedia-22-12", "simple", split='train', streaming=True)
    # Lazily convert each dataset row into a bulk-index action.
    all_docs = (
        {"_index": index, "_id": row['id'], "_source": {"text": row['text']}}
        for row in data
    )
    print("Start index docs. This might take few minutes.")
    helpers.bulk(es, all_docs)
    # Make the freshly indexed documents visible to the search that follows.
    es.indices.refresh(index=index)

# Traditional lexical search with ES
query = "Cats lifespan"

# Retrieve top-100 documents from ES lexical search
resp = es.search(index=index, size=100, query={'query_string': {'query': query}})
docs = [hit['_source']['text'] for hit in resp['hits']['hits']]

print("Elasticsearch Lexical Search results:")
for doc in docs[:3]:
    print(doc)
    print("-----")

# Re-Rank them with cohere
print("\n===========")
print("ReRank results:")
if docs:
    rerank_hits = co.rerank(query=query, documents=docs, top_n=3, model='rerank-multilingual-02')
    # Newer Cohere SDKs expose the ranked hits on `.results`; older ones are
    # directly iterable. Fall back to the response object itself when there is
    # no `.results` attribute.
    for hit in getattr(rerank_hits, "results", rerank_hits):
        # `hit.index` is the position of the document in the `docs` list we sent.
        print(docs[hit.index])
        print("-----")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment