import os
import json
import requests
import time
from typing import List, Dict, Union
import requests
from datetime import datetime

"""
This sample script shows how to upload a TIF image and its associated IMD metadata file
to the BEOD platform using the API.

The upload creates a user-image raster asset that can be used as the source for an extraction.

It performs the following steps:
1. Logs in the user and gets the access token.
2. Creates a project.
3. Creates an AOI (Area of Interest) from a GeoJSON file.
4. Uploads the TIF image and its associated IMD metadata file in multiple parts.
5. Follows the activity to see the progress of the post-processing of the uploaded asset.
"""

# Files are uploaded in chunks of 10 MB (a single part of an S3 multipart
# upload can be at most 5 GB)
CHUNK_SIZE = 10 * 1024 ** 2

# dev API URL (uncomment to target the dev platform instead of production)
# API_URL = "https://dev.api.beod.luxcarta.cloud/v1"

# production API URL
API_URL = "https://api.beod.luxcarta.cloud/v1"


def check_response(expected_code: int, response) -> bool:
    """
    Tell whether a response carries the expected HTTP status code.
    On a mismatch, the status code, reason and text of the response are
    printed to help diagnose the failure.
    Args:
        expected_code: int: The status code the caller expects
        response: The response to check; its status_code, reason and text
            attributes are read
    Returns:
        bool: True if the status code is the expected one, False otherwise
    """

    if response.status_code == expected_code:
        return True

    # unexpected status: dump everything the response tells about the failure
    print("FAILED")
    print("STATUS CODE:", response.status_code, " - EXPECTED:", expected_code)
    print("REASON:", response.reason)
    print("TEXT:", response.text)
    return False


def get_aoi_polygon(geojson_file) -> str:
    """
    Get the WKT polygon from a GeoJSON file.
    The geometry of the first feature in the GeoJSON is used, and only its
    first ring (the exterior ring of the polygon) is converted.
    Args:
        geojson_file: The path of the GeoJSON file to read
    Returns:
        str: The WKT polygon, e.g. "POLYGON ((0 0, 1 0, 1 1, 0 0))"
    Raises:
        ValueError: if the file contains no feature, if the first feature has
            no usable geometry (missing, null or without coordinates), if the
            geometry is not a Polygon or if the polygon has no coordinates
    """

    # GeoJSON is UTF-8 encoded: do not depend on the platform default encoding
    with open(geojson_file, "r", encoding="utf-8") as f:
        geojson_data = json.load(f)

    if not geojson_data or "features" not in geojson_data or not geojson_data["features"]:
        raise ValueError("GeoJSON file is empty or does not contain features.")

    # a feature may carry a null geometry, so make sure we really have a mapping
    # before looking inside it
    first_feature = geojson_data["features"][0]
    geometry = first_feature.get("geometry") if isinstance(first_feature, dict) else None
    if not isinstance(geometry, dict) or "coordinates" not in geometry:
        raise ValueError("GeoJSON feature does not contain geometry or coordinates.")

    if geometry.get("type") != "Polygon":
        raise ValueError("GeoJSON feature is not a Polygon.")

    rings = geometry["coordinates"]
    if not rings or not rings[0]:
        raise ValueError("GeoJSON polygon does not contain any coordinates.")

    # each position becomes "x y", positions are separated by ", "
    exterior = ", ".join(" ".join(map(str, position)) for position in rings[0])
    return f"POLYGON (({exterior}))"


def login(username: str, password: str) -> dict:
    """
    Authenticate the user against the API and return the session tokens.
    Args:
        username: str: The username of the user (sent as the "email" field)
        password: str: The password of the user
    Returns:
        dict: The decoded JSON body of the response when the API answers with
        a 200 status code, for instance:
        {
            "refresh_token": "eyJjdHkiOiJKV1QiLCJlbmMiOiJB...",
            "access_token": "eyJraWQiOiJmK2czdXc4R05aZEc0WGs...",
            "expires_in": 86400,
            "token_type": "Bearer"
        }
        None when the API answers with any other status code
    """

    credentials = {"email": username, "password": password}
    response = requests.post(f"{API_URL}/auth/login", json=credentials)

    # anything but a 200 is treated as a failed login
    if response.status_code != 200:
        return None
    return response.json()


def create_project(token: str, project_name: str) -> int:
    """
    Create a new project for the logged-in user.
    Args:
        token: str: The access token of the user, sent as a bearer token
        project_name: str: The name to give to the project
    Returns:
        int: The "project_id" entry of the response when the API answers with
        a 201 (created) status code, None otherwise
    """

    response = requests.post(
        f"{API_URL}/users/me/projects",
        json={"name": project_name},
        headers={"Authorization": f"Bearer {token}"},
    )

    # a successful creation is reported with a 201 status code (created),
    # the body then carries the id of the new project
    if response.status_code != 201:
        return None
    return response.json()["project_id"]


def create_aoi(token: str, project_id: str, wkt_geometry: str) -> int:
    """
    Create an AOI asset with the given geometry in the specified project.
    Args:
        token: str: The access token of the user, sent as a bearer token
        project_id: The id of the project (inserted as-is in the request URL)
        wkt_geometry: str: The WKT geometry of the AOI
    Returns:
        int: The "asset_id" entry of the response when the API answers with
        a 201 (created) status code, None otherwise
    """

    response = requests.post(
        f"{API_URL}/users/me/projects/{project_id}/assets/aoi",
        json={"geometry": wkt_geometry},
        headers={"Authorization": f"Bearer {token}"},
    )

    # a successful creation is reported with a 201 status code (created),
    # the body then carries the id of the new AOI asset
    if response.status_code != 201:
        return None
    return response.json()["asset_id"]


def follow_activity(
    token: str,
    project_id: int,
    activity_id: int,
    poll_interval: float = 1.0,
) -> bool:
    """
    Follow an activity in a project until it terminates.
    The activity is polled repeatedly and its state, optional details and
    progress are printed after each poll. Polling only stops when the state is
    "SUCCESSFUL" or "FAILED", or when the activity cannot be read anymore.
    Args:
        token: str: The access token of the user
        project_id: int: The id of the project (inserted as-is in the request URL)
        activity_id: int: The id of the activity to follow
        poll_interval: float: The number of seconds to wait between two polls
            (default is 1 second)
    Returns:
        bool: True if the activity terminated successfully, False if it failed
        or if the API answered with a status code other than 200
    """

    url = f"{API_URL}/users/me/projects/{project_id}/activities/{activity_id}"

    # use the token as a bearer token in the headers
    headers = {"Authorization": f"Bearer {token}"}

    while True:
        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            print('Unable to track the activity')
            return False

        # get status and progress, the details are only reported when present
        activity = response.json()
        state = activity["state"]
        progress = activity["progress"]
        details = f", details: {activity['details']}" if 'details' in activity else ''
        print(f'> Activity state: {state}{details}, progress: {progress}')

        if state == "SUCCESSFUL":
            return True
        if state == "FAILED":
            return False

        # any other state means the activity is still running: wait and poll again
        time.sleep(poll_interval)


# --------------------------------------------------------------------------------
# UPLOAD PART
# --------------------------------------------------------------------------------

def upload_one_file(
    filename: str,
    parts: List[int],
    urls: List[str],
) -> List[Dict[str, Union[int, str]]]:
    """
    Upload the requested parts of a single file, each one to its own URL.
    Note that in this sample tutorial script the parts are uploaded one after
    the other; this is not the most efficient way to upload files as parallel
    upload is fully supported by AWS.
    Args:
        filename: str: The name of the file to upload.
        parts: List[int]: The list of part numbers to upload (part 1 starts at
            the beginning of the file, each part covers CHUNK_SIZE bytes).
        urls: List[str]: The list of URLs to upload the parts to, in the same
            order as `parts`.
    Returns:
        List[Dict[str, Union[int, str]]]: A list of dictionaries containing the
            "ETag" and "PartNumber" of each uploaded part, or an empty list as
            soon as one part is not accepted with a 200 status code.
    """

    completed = []
    with open(filename, "rb") as stream:
        for position, part_number in enumerate(parts):
            # part numbers start at 1, hence the -1 to get the byte offset
            stream.seek((part_number - 1) * CHUNK_SIZE)
            payload = stream.read(CHUNK_SIZE)

            print(f"uploading part {part_number} of {filename}...")
            response = requests.put(urls[position], data=payload, headers={})

            # give up on the first part that is not accepted
            if not check_response(200, response):
                return []

            # the ETag of each part is needed later to complete the upload
            completed.append({"ETag": response.headers["ETag"], "PartNumber": part_number})

    return completed


def get_local_files_chunks(files: List[str], chunk_size: int = CHUNK_SIZE) -> List[Dict]:
    """
    Describe each local file and the number of chunks needed to upload it.
    Args:
        files: List[str]: The list of local file paths to upload.
        chunk_size: int: The size of each chunk in bytes (default is CHUNK_SIZE).
    Returns:
        List[Dict]: One dictionary per file, in the same order as `files`, with
            the entries "name" (the path as given), "size" (the file size in
            bytes) and "parts_count" (the number of chunks of the file, the
            last one possibly being smaller than chunk_size).
    """

    def describe(path: str) -> Dict:
        size = os.path.getsize(path)
        # a partially filled last chunk still counts as one part
        full_chunks, leftover = divmod(size, chunk_size)
        return {
            "name": path,
            "size": size,
            "parts_count": full_chunks + (1 if leftover > 0 else 0),
        }

    return [describe(path) for path in files]


def upload_image_asset(
    token: str,
    project_id: int,
    files: List[str],
):
    """
    Upload local files as one raster asset of a project, then follow its post-processing.
    The upload is done in three steps:
    1. the multi-files, multi-parts upload is initialized (assets/upload/initialize),
    2. for each file, the part URLs are requested (upload/multi-part-urls), the
       parts are uploaded and the file upload is completed (upload/complete),
    3. the activity returned by the initialization is followed until it terminates.
    The function stops (after printing the failure) as soon as one of the API
    calls does not answer with the expected status code, or when a file could
    not be uploaded entirely.
    Args:
        token: str: The access token of the user
        project_id: int: The id of the project receiving the asset
        files: List[str]: The local paths of the files making up the asset
    Returns:
        None
    """

    headers = {"Authorization": f"Bearer {token}"}
    assets_url = f"{API_URL}/users/me/projects/{project_id}/assets"

    # ------------------------------------------------------------------------------------------------
    # the first step is to compute the chunks for each file to upload and to
    # call the assets/upload/initialize

    files_chunks = get_local_files_chunks(files)

    # initialize the multi-files, multi-parts upload
    initialize_request = {
        "asset_name": "my uploaded image",
        "asset_type": "raster", # raster, vector or aoi
        "asset_files": files_chunks
    }

    response = requests.post(
        f"{assets_url}/upload/initialize",
        json=initialize_request,
        headers=headers,
    )

    if not check_response(201, response):
        return

    initialize_response = response.json()
    uploaded_asset_id = initialize_response["asset_id"]
    activity_id = initialize_response["activity_id"]

    # ------------------------------------------------------------------------------------------------
    # the second step is to get for each file the upload URLs for the parts with the
    # assets/upload/multi-part-urls route, which returns a list of URLs for each part of each file.
    # we will then upload the parts of each file to the corresponding URLs.
    # when all parts of a file are uploaded, we complete the upload by calling the route
    # assets/upload/complete with the upload_id, file_key and the list of uploaded parts.

    # the entries of "uploads" are keyed by file name: look the chunks up by name
    # rather than relying on the response listing the files in the request order
    chunks_by_name = {chunks["name"]: chunks for chunks in files_chunks}

    for file_index, (upload_file, upload_info) in enumerate(initialize_response["uploads"].items()):
        file_key, upload_id = upload_info

        file_chunks = chunks_by_name.get(upload_file)
        if file_chunks is None:
            # unknown name: fall back to the position of the file in the request
            file_chunks = files_chunks[file_index]
        parts = list(range(1, file_chunks["parts_count"] + 1))

        # get the upload urls for the file parts
        get_multi_part_urls_request = {
            "upload_id": upload_id,
            "file_key": file_key,
            "parts": parts,
        }

        response = requests.post(
            f"{assets_url}/{uploaded_asset_id}/upload/multi-part-urls",
            json=get_multi_part_urls_request,
            headers=headers,
        )

        if not check_response(200, response):
            return

        get_multi_part_urls_response = response.json()

        # upload the file parts one after the other
        uploaded_parts = upload_one_file(
            upload_file,
            parts,
            get_multi_part_urls_response["urls"]
        )

        # upload_one_file returns an empty list when a part is rejected: do not
        # try to complete an upload whose parts are missing
        if len(uploaded_parts) != len(parts):
            print(f"Upload of {upload_file} failed, the asset upload is aborted")
            return

        # complete the upload by calling the assets/upload/complete route
        complete_upload_request = {
            "upload_id": upload_id,
            "file_key": file_key,
            "parts": uploaded_parts
        }

        response = requests.post(
            f"{assets_url}/{uploaded_asset_id}/upload/complete",
            json=complete_upload_request,
            headers=headers,
        )

        if not check_response(201, response):
            return

        complete_response = response.json()
        print(f"File {file_key} - status = {complete_response['status']}")

    # ------------------------------------------------------------------------------------------------
    # the last step is to follow the activity to see the progress of the post-processing of
    # the uploaded asset. The activity id is returned in the response of the initialization step.

    print("Following the activity to see the progress of the post-processing of the uploaded asset...")
    follow_activity(token, project_id, activity_id)


if __name__ == "__main__":

    # local data files
    aoi_file = "data/aoi.geojson"
    tif_file = "data/image.tif"
    imd_file = "data/image.imd"

    # credentials must be defined in the environment variables
    if 'TEST_USER' not in os.environ or 'TEST_PASSWORD' not in os.environ:
        print("Please set the TEST_USER and TEST_PASSWORD environment variables.")
        raise SystemExit(1)

    # get the user credentials from the environment variables
    user = os.environ['TEST_USER']
    password = os.environ['TEST_PASSWORD']

    # first setup a project with an AOI which fits the uploaded image
    aoi_wkt_polygon = get_aoi_polygon(geojson_file=aoi_file)

    # login returns None on failure: check it before reading the token.
    # explicit checks are used instead of assert, which is disabled by python -O
    login_response = login(user, password)
    if login_response is None:
        raise SystemExit("Login failed, please check your credentials")
    token = login_response["access_token"]

    project_name = f"Upload Tutorial project {datetime.now().strftime('%Y-%m-%d/%H:%M:%S')}"
    project_id = create_project(token, project_name)
    if project_id is None:
        raise SystemExit("Project creation failed")

    aoi_id = create_aoi(token, project_id, aoi_wkt_polygon)
    if aoi_id is None:
        raise SystemExit("AOI creation failed")

    # upload the TIF image and its associated IMD metadata file
    files = [tif_file, imd_file]
    upload_image_asset(token, project_id, files)
