# The MIT License (MIT)
# Copyright © 2024 Opentensor Foundation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import base64
from typing import Optional, Union
import msgpack
import msgpack_numpy
import numpy as np
from pydantic import ConfigDict, BaseModel, Field, field_validator
from bittensor.utils.registration import torch, use_torch
[docs]
class DTypes(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.torch: bool = False
        self.update(
            {
                "float16": np.float16,
                "float32": np.float32,
                "float64": np.float64,
                "uint8": np.uint8,
                "int16": np.int16,
                "int8": np.int8,
                "int32": np.int32,
                "int64": np.int64,
                "bool": bool,
            }
        )
[docs]
    def __getitem__(self, key):
        self._add_torch()
        return super().__getitem__(key) 
[docs]
    def __contains__(self, key):
        self._add_torch()
        return super().__contains__(key) 
[docs]
    def _add_torch(self):
        if self.torch is False:
            torch_dtypes = {
                "torch.float16": torch.float16,
                "torch.float32": torch.float32,
                "torch.float64": torch.float64,
                "torch.uint8": torch.uint8,
                "torch.int16": torch.int16,
                "torch.int8": torch.int8,
                "torch.int32": torch.int32,
                "torch.int64": torch.int64,
                "torch.bool": torch.bool,
            }
            self.update(torch_dtypes)
            self.torch = True 
 
dtypes = DTypes()
[docs]
def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> Optional[str]:
    """
    Casts the raw value to a string representing the `numpy data type <https://numpy.org/doc/stable/user/basics.types.html>`_, or the `torch data type <https://pytorch.org/docs/stable/tensor_attributes.html>`_ if using torch.
    Args:
        raw (Union[None, numpy.dtype, torch.dtype, str]): The raw value to cast.
    Returns:
        str: The string representing the numpy/torch data type.
    Raises:
        Exception: If the raw value is of an invalid type.
    """
    if not raw:
        return None
    if use_torch() and isinstance(raw, torch.dtype):
        return dtypes[raw]
    elif isinstance(raw, np.dtype):
        return dtypes[raw]
    elif isinstance(raw, str):
        if use_torch():
            assert raw in dtypes, f"{raw} not a valid torch type in dict {dtypes}"
            return raw
        else:
            assert raw in dtypes, f"{raw} not a valid numpy type in dict {dtypes}"
            return raw
    else:
        raise Exception(
            f"{raw} of type {type(raw)} does not have a valid type in Union[None, numpy.dtype, torch.dtype, str]"
        ) 
[docs]
def cast_shape(raw: Union[None, list[int], str]) -> Optional[Union[str, list]]:
    """
    Casts the raw value to a string representing the tensor shape.
    Args:
        raw (Union[None, list[int], str]): The raw value to cast.
    Returns:
        str: The string representing the tensor shape.
    Raises:
        Exception: If the raw value is of an invalid type or if the list elements are not of type int.
    """
    if not raw:
        return None
    elif isinstance(raw, list):
        if len(raw) == 0 or isinstance(raw[0], int):
            return raw
        else:
            raise Exception(f"{raw} list elements are not of type int")
    elif isinstance(raw, str):
        shape = list(map(int, raw.split("[")[1].split("]")[0].split(",")))
        return shape
    else:
        raise Exception(
            f"{raw} of type {type(raw)} does not have a valid type in Union[None, list[int], str]"
        ) 
[docs]
class tensor:
    def __new__(cls, tensor: Union[list, "np.ndarray", "torch.Tensor"]):
        if isinstance(tensor, list) or isinstance(tensor, np.ndarray):
            tensor = torch.tensor(tensor) if use_torch() else np.array(tensor)
        return Tensor.serialize(tensor_=tensor) 
[docs]
class Tensor(BaseModel):
    """
    Represents a Tensor object.
    Args:
        buffer (Optional[str]): Tensor buffer data.
        dtype (str): Tensor data type.
        shape (list[int]): Tensor shape.
    """
    model_config = ConfigDict(validate_assignment=True)
[docs]
    def tensor(self) -> Union[np.ndarray, "torch.Tensor"]:
        return self.deserialize() 
[docs]
    def tolist(self) -> list[object]:
        return self.deserialize().tolist() 
[docs]
    def numpy(self) -> "np.ndarray":
        return (
            self.deserialize().detach().numpy() if use_torch() else self.deserialize()
        ) 
[docs]
    def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]:
        """
        Deserializes the Tensor object.
        Returns:
            np.array or torch.Tensor: The deserialized tensor object.
        Raises:
            Exception: If the deserialization process encounters an error.
        """
        shape = tuple(self.shape)
        buffer_bytes = base64.b64decode(self.buffer.encode("utf-8"))
        numpy_object = msgpack.unpackb(
            buffer_bytes, object_hook=msgpack_numpy.decode
        ).copy()
        if use_torch():
            torch_object = torch.as_tensor(numpy_object)
            # Reshape does not work for (0) or [0]
            if not (len(shape) == 1 and shape[0] == 0):
                torch_object = torch_object.reshape(shape)
            return torch_object.type(dtypes[self.dtype])
        else:
            # Reshape does not work for (0) or [0]
            if not (len(shape) == 1 and shape[0] == 0):
                numpy_object = numpy_object.reshape(shape)
            return numpy_object.astype(dtypes[self.dtype]) 
[docs]
    @staticmethod
    def serialize(tensor_: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
        """
        Serializes the given tensor.
        Args:
            tensor_ (np.array or torch.Tensor): The tensor to serialize.
        Returns:
            :func:`Tensor`: The serialized tensor.
        Raises:
            Exception: If the serialization process encounters an error.
        """
        dtype = str(tensor_.dtype)
        shape = list(tensor_.shape)
        if len(shape) == 0:
            shape = [0]
        tensor__ = tensor_.cpu().detach().numpy().copy() if use_torch() else tensor_
        data_buffer = base64.b64encode(
            msgpack.packb(tensor__, default=msgpack_numpy.encode)
        ).decode("utf-8")
        return Tensor(buffer=data_buffer, shape=shape, dtype=dtype) 
    # Represents the tensor buffer data.
    buffer: Optional[str] = Field(
        default=None,
        title="buffer",
        description="Tensor buffer data. This field stores the serialized representation of the tensor data.",
        examples=["0x321e13edqwds231231231232131"],
        frozen=True,
        repr=False,
    )
    # Represents the data type of the tensor.
    dtype: str = Field(
        title="dtype",
        description="Tensor data type. This field specifies the data type of the tensor, such as numpy.float32 or torch.int64.",
        examples=["np.float32"],
        frozen=True,
        repr=True,
    )
    # Represents the shape of the tensor.
    shape: list[int] = Field(
        title="shape",
        description="Tensor shape. This field defines the dimensions of the tensor as a list of integers, such as [10, 10] for a 2D tensor with shape (10, 10).",
        examples=[10, 10],
        frozen=True,
        repr=True,
    )
    # Extract the represented shape of the tensor.
    _extract_shape = field_validator("shape", mode="before")(cast_shape)
    # Extract the represented data type of the tensor.
    _extract_dtype = field_validator("dtype", mode="before")(cast_dtype)