|
import os |
|
import csv |
|
import boto3 |
|
import json |
|
from datetime import date |
|
|
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
|
|
# CSV field separator used by the input dataset.
# NOTE(review): name is a typo for "delimiter"; kept as-is because other
# functions in this module reference it.
delimeter = ";"


# AWS region for both clients; defaults to us-east-1 when unset.
region_name = os.environ.get('REGION_NAME', 'us-east-1')

# Bedrock embedding model id -- passed to invoke_model, so it must be set
# before any embedding call.
model_id = os.environ.get('MODEL_ID')

# Target S3 Vectors bucket and index used by all put/query/delete calls.
vector_bucket_name = os.environ.get('S3_VECTOR_BUCKET_NAME')

index_name = os.environ.get('S3_VECTOR_INDEX_NAME')


# Path to the CSV dataset consumed by delete_vectors (and by callers of
# load_documents).
dataset_path = os.environ.get('INPUT_DATASET_PATH')


# Cap on how many dataset rows are processed; -1 means "no limit".
max_input_documents = int(os.environ.get('MAX_INPUT_DOCUMENTS', '-1'))


# Shared AWS clients: Bedrock runtime for embeddings, S3 Vectors for storage.
bedrock = boto3.client('bedrock-runtime', region_name=region_name)

s3vectors = boto3.client('s3vectors', region_name=region_name)
|
|
|
|
|
def read_from_csv(file_path, limit=-1, delimiter=";"):
    '''Yield rows from a delimited CSV dataset file as dicts.

    The first line of the file is treated as the header row.

    Parameters:
        file_path: path to the CSV file.
        limit: maximum number of rows to yield; -1 (the default) yields all.
        delimiter: field separator; defaults to ";" to match the dataset.

    Bug fix: the original read the header with fp.readline() and split it
    manually, which left a trailing newline on the last column name (so the
    last field was keyed e.g. 'Description\n'). DictReader now parses the
    header itself, stripping the line ending correctly.
    '''
    # newline='' is the documented way to open files for the csv module.
    with open(file_path, 'r', newline='') as fp:
        reader = csv.DictReader(fp, delimiter=delimiter)

        for count, row in enumerate(reader, start=1):
            yield row

            # limit == -1 means "no limit"; otherwise stop after `limit` rows.
            if limit != -1 and count >= limit:
                break
|
|
|
|
|
def create_embedding(text):
    '''Return the embedding vector for *text*.

    Invokes the configured Bedrock model (`model_id`) with the text as
    {"inputText": ...} and extracts the "embedding" field from the JSON
    response body.
    '''
    request_body = json.dumps({"inputText": text})

    raw_response = bedrock.invoke_model(modelId=model_id, body=request_body)

    payload = json.loads(raw_response['body'].read())
    return payload["embedding"]
|
|
|
|
|
def create_vector(doc):
    '''Build an S3 Vectors entry (key, data, metadata) for one dataset row.

    Embeds the row's Description and attaches filterable metadata, with the
    event date broken out into year and month for filtered queries.
    '''
    # Embed first (matches original call order), then parse the event date.
    embedding_vector = create_embedding(doc['Description'])
    event_date = date.fromisoformat(doc['Local_Date_Main_Event'])

    metadata = {
        "occurrence_id": doc['Occurrence_Id'],
        "local_date": doc['Local_Date_Main_Event'],
        "local_date_year": event_date.year,
        "local_date_month": event_date.month,
        "state_affected": doc['Coastal_State_Affected'],
        "source_text": doc['Description'],
    }

    return {
        "key": doc['Occurrence_Id'],
        "data": {"float32": embedding_vector},
        "metadata": metadata,
    }
|
|
|
|
|
def load_documents(file_path, batch_size=500):
    '''Embed the dataset rows and upsert them into the S3 Vectors index.

    Reads up to `max_input_documents` rows from the CSV file, embeds each
    row's Description via create_vector, and uploads the vectors in chunks.

    Parameters:
        file_path: path to the CSV dataset.
        batch_size: vectors per PutVectors request; defaults to 500, the
            documented per-request maximum for the S3 Vectors API.

    Fixes the original behavior of sending the whole dataset in a single
    put_vectors call, which fails once the dataset exceeds the API's
    per-request vector limit; also skips the API call entirely when the
    input yields no rows.
    '''
    batch = []

    for doc in read_from_csv(file_path, limit=max_input_documents):
        batch.append(create_vector(doc))

        if len(batch) >= batch_size:
            s3vectors.put_vectors(
                vectorBucketName=vector_bucket_name,
                indexName=index_name,
                vectors=batch
            )
            batch = []

    # Flush the final partial batch, if any.
    if batch:
        s3vectors.put_vectors(
            vectorBucketName=vector_bucket_name,
            indexName=index_name,
            vectors=batch
        )
|
|
|
|
|
def list_vectors():
    '''Return all vectors in the index, following pagination.

    The ListVectors API returns results in pages via nextToken; the original
    implementation returned only the first page, silently truncating large
    indexes. This accumulates every page into one list.
    '''
    all_vectors = []
    next_token = None

    while True:
        kwargs = {
            "vectorBucketName": vector_bucket_name,
            "indexName": index_name,
        }
        if next_token:
            kwargs["nextToken"] = next_token

        response = s3vectors.list_vectors(**kwargs)
        all_vectors.extend(response['vectors'])

        next_token = response.get('nextToken')
        if not next_token:
            return all_vectors
|
|
|
|
|
def query_documents(query, top_k=3):
    '''Run a semantic similarity search against the vector index.

    Embeds *query* with Bedrock, then asks S3 Vectors for the `top_k`
    nearest vectors, returning them with distance and metadata included.
    '''
    request = dict(
        vectorBucketName=vector_bucket_name,
        indexName=index_name,
        queryVector={"float32": create_embedding(query)},
        topK=top_k,
        returnDistance=True,
        returnMetadata=True,
    )

    return s3vectors.query_vectors(**request)["vectors"]
|
|
|
def query_documents_with_filter(query, year, month, top_k=3):
    '''Semantic search restricted to a given year and month.

    Embeds *query*, then queries the index with a metadata filter requiring
    both local_date_year == year and local_date_month == month.
    '''
    # Both date conditions must hold for a vector to match.
    date_filter = {
        "$and": [
            {"local_date_year": {"$eq": year}},
            {"local_date_month": {"$eq": month}},
        ]
    }

    result = s3vectors.query_vectors(
        vectorBucketName=vector_bucket_name,
        indexName=index_name,
        queryVector={"float32": create_embedding(query)},
        topK=top_k,
        filter=date_filter,
        returnDistance=True,
        returnMetadata=True,
    )

    return result["vectors"]
|
|
|
|
|
def delete_vectors():
    '''Delete every vector whose key appears in the configured dataset.

    Re-reads the same CSV slice that was loaded (respecting
    max_input_documents) and deletes the vectors by their Occurrence_Id keys.
    NOTE(review): DeleteVectors likely has a per-request key limit; very
    large datasets may need chunked deletion -- confirm against the API docs.
    '''
    keys_to_delete = [
        row['Occurrence_Id']
        for row in read_from_csv(dataset_path, limit=max_input_documents)
    ]

    return s3vectors.delete_vectors(
        vectorBucketName=vector_bucket_name,
        indexName=index_name,
        keys=keys_to_delete,
    )