{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "41d2a0bb-4dc9-4633-84db-c46b1c7927a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv())\n",
    "openai_api_key = os.environ[\"OPENAI_API_KEY\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7549c5a-a40a-4cf2-a809-f34b027bdd63",
   "metadata": {},
   "source": [
    "## Basic app to extract from a ChatMessage the song and artist a user wants to play"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e42ad50-213a-4e45-a7e9-9dd9a04b4f57",
   "metadata": {},
   "source": [
    "**Define your extraction goal (called \"the response schema\")**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "40dcbf89-52a0-4301-b3e8-848eff6008dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers import ResponseSchema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "52a97f7a-c013-4d6e-aace-43cdd8c847ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_schemas = [\n",
    "    ResponseSchema(\n",
    "        name=\"singer\",\n",
    "        description=\"name of the singer\"\n",
    "    ),\n",
    "    ResponseSchema(\n",
    "        name=\"song\",\n",
    "        description=\"name of the song\"\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0be924-6097-4595-86ae-9ab0f7363709",
   "metadata": {},
   "source": [
    "**Create the Output Parser that will extract the data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "214e1af9-5610-42b9-a1bb-d16c8656d566",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers import StructuredOutputParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "66cbc4c3-1af7-48d9-b4f8-de14abf457f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_parser = StructuredOutputParser.from_response_schemas(response_schemas)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d31c1af-2241-4543-a7a7-4f28778f998c",
   "metadata": {},
   "source": [
    "**Create the format instructions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2486e439-e50b-4acc-9870-4ebdd846d53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "format_instructions = output_parser.get_format_instructions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "044908b3-6063-4363-916c-1b5b919eb384",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The output should be a markdown code snippet formatted in the following schema, including the leading and trailing \"```json\" and \"```\":\n",
      "\n",
      "```json\n",
      "{\n",
      "\t\"singer\": string  // name of the singer\n",
      "\t\"song\": string  // name of the song\n",
      "}\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "print(format_instructions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d03a2f22-dcfa-4406-b1ca-af7c86e11957",
   "metadata": {},
   "source": [
    "**Create the ChatPromptTemplate**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e09fc699-3a88-4726-93ec-57dd7cbafe84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.prompts import HumanMessagePromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5a8c79b8-25d0-49e0-ba0a-785112500c37",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_model = ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3ed7f2f5-d4cf-4501-be77-03608f5e4e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"\n",
    "Given a command from the user,\n",
    "extract the artist and song names\n",
    "{format_instructions}\n",
    "{user_prompt}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "54973460-9795-4a30-823c-2193ecbb4aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate(\n",
    "    messages=[\n",
    "        HumanMessagePromptTemplate.from_template(template)\n",
    "    ],\n",
    "    input_variables={\"user_prompt\"},\n",
    "    partial_variables={\"format_instructions\": format_instructions}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09bc3403-bf6a-4adb-ae85-7629705791e9",
   "metadata": {},
   "source": [
    "**Enter the chat message from the user**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "57a0f437-6eca-4fe9-963e-43b0fe273f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_message = prompt.format_prompt(\n",
    "    user_prompt=\"I like the song New York, New York by Frank Sinatra\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "16d5cf20-3f41-41f0-9db9-bce0a0104e32",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_chat_message = chat_model(user_message.to_messages())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f78a68c-a21f-421f-9bfc-8239e560fec5",
   "metadata": {},
   "source": [
    "**Extract singer and song from the user message**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "dd865021-3a86-44f0-9938-32a485a8e4ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "extraction = output_parser.parse(user_chat_message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "91f0aa98-b68d-43cd-8415-99e517834992",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'singer': 'Frank Sinatra', 'song': 'New York, New York'}\n"
     ]
    }
   ],
   "source": [
    "print(extraction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6424f00e-69e7-4ebb-8bd7-1fbddea07603",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n"
     ]
    }
   ],
   "source": [
    "print(type(extraction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cf8e09-b72f-4dc8-98fd-1abbb1eba3b0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
