Spaces:
Runtime error
Runtime error
| from abc import ABC | |
| from typing import Any, Dict, List, Literal, TypedDict, Union, cast | |
| from pydantic import BaseModel, PrivateAttr | |
class BaseSerialized(TypedDict):
    """Keys common to every serialized payload defined in this module."""

    # Integer tag; the serializers in this file always write 1 here.
    lc: int
    # Identifier expressed as a list of string segments.
    id: List[str]
class SerializedConstructor(BaseSerialized):
    """Serialized form of an object described by its constructor kwargs."""

    # Discriminator: always the literal string "constructor".
    type: Literal["constructor"]
    # Keyword arguments recorded for the serialized object.
    kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
    """Placeholder payload that stands in for a secret value."""

    # Discriminator: always the literal string "secret".
    type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
    """Payload emitted for an object that cannot be serialized."""

    # Discriminator: always the literal string "not_implemented".
    type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
    """Serializable base class.

    Subclasses opt in to serialization by overriding the ``lc_*`` properties.
    The keyword arguments passed to the constructor are remembered in the
    private ``_lc_kwargs`` attribute and form the basis of ``to_json()``.
    """

    # NOTE: the four ``lc_*`` accessors must be properties, because
    # ``to_json`` reads them as plain attributes (``self.lc_serializable``,
    # ``this.lc_secrets`` ...). As ordinary methods the "not serializable"
    # branch could never trigger and ``dict.update`` would receive a bound
    # method.

    @property
    def lc_serializable(self) -> bool:
        """Return whether or not the class is serializable."""
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """Return the namespace of the langchain object.

        This is the defining module's dotted path split on ".",
        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Return a map of constructor argument names to secret ids.

        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """Return a map of attribute names to values to add to the kwargs.

        The entries are merged into the serialized kwargs, so these
        attributes must be accepted by the constructor.
        """
        return {}

    class Config:
        # pydantic setting: extra constructor kwargs are ignored, not rejected.
        extra = "ignore"

    # Constructor kwargs exactly as received; filled in by ``__init__``.
    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        """Initialise the model and remember the raw constructor kwargs."""
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        """Serialize this object.

        Returns:
            A "not_implemented" payload when ``lc_serializable`` is falsy,
            otherwise a "constructor" payload whose kwargs have every key
            listed in ``lc_secrets`` replaced by a secret marker.
        """
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets: Dict[str, str] = {}
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self._lc_kwargs.items()
            if not (self.__exclude_fields__ or {}).get(k, False)  # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        # include all secrets, even if not specified in kwargs
        # as these secrets may be passed as an environment variable instead
        for key in secrets:
            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
            if secret_value is not None:
                lc_kwargs[key] = secret_value

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        """Return the "not_implemented" payload for this object."""
        return to_json_not_implemented(self)
| def _replace_secrets( | |
| root: Dict[Any, Any], secrets_map: Dict[str, str] | |
| ) -> Dict[Any, Any]: | |
| result = root.copy() | |
| for path, secret_id in secrets_map.items(): | |
| [*parts, last] = path.split(".") | |
| current = result | |
| for part in parts: | |
| if part not in current: | |
| break | |
| current[part] = current[part].copy() | |
| current = current[part] | |
| if last in current: | |
| current[last] = { | |
| "lc": 1, | |
| "type": "secret", | |
| "id": [secret_id], | |
| } | |
| return result | |
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Build the "not implemented" payload for an arbitrary object.

    The identifier is the dotted module path split into segments followed by
    a name. Objects that expose ``__name__`` are identified by their own
    module and name; everything else is identified through its class.

    Args:
        obj: object to serialize

    Returns:
        SerializedNotImplemented whose "id" is empty when the identifier
        could not be determined.
    """
    identifier: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            target = obj
        else:
            target = obj.__class__
        identifier = [*target.__module__.split("."), target.__name__]
    except Exception:
        # Deliberately best-effort: any lookup failure leaves the id empty.
        pass

    payload = {
        "lc": 1,
        "type": "not_implemented",
        "id": identifier,
    }
    return payload