diff --git a/ActorCritic2.ipynb b/ActorCritic2.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f7cd09fe214fb05684a30e5b3034521c0e8a30eb --- /dev/null +++ b/ActorCritic2.ipynb @@ -0,0 +1,3395 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 360, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import torch\n", + "import read_maze as maze\n", + "from enum import Enum\n", + "import torch.optim\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "import torch.nn.functional as F\n", + "import random\n", + "from torch.distributions import Categorical\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 361, + "outputs": [], + "source": [ + "maze.load_maze()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 362, + "outputs": [], + "source": [ + "class envirnoment():\n", + "\n", + " def __init__(self,start, end):\n", + " self.start = torch.clone(start)\n", + " self.location = torch.clone(start)\n", + " self.end = torch.clone(end)\n", + " self.around = torch.from_numpy(maze.get_local_maze_information(*self.location))\n", + " self.action_space = 5\n", + " self.time_step = torch.tensor([0])\n", + "\n", + "\n", + " def reset(self):\n", + " self.location = torch.clone(self.start)\n", + " self.around = torch.from_numpy(maze.get_local_maze_information(*self.location))\n", + " self.time_step = torch.tensor([0])\n", + "\n", + " state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1)\n", + " return state\n", + "\n", + "\n", + " def take_action(self, action):\n", + " # NEXT IS WHERE THE AGENT WILL LAND AFTER PERFORMING AN ACTION\n", + " next = torch.clone(self.location)\n", + " reward = 0\n", + "\n", + " if action == 1:\n", + " if self.around[1][0][0] == 1:# and self.around[0][1][1] == 0:\n", + " next[1] += -1\n", + " reward -= 1\n", + " #print('up')\n", + "\n", + " elif action == 2:\n", + "\n", + " if self.around[2][1][0] == 1:# and self.around[1][2][1] == 0:\n", + " next[0] += 1\n", + " reward += 1\n", + " #print('right')\n", + "\n", + " elif action == 3:\n", + " if self.around[1][2][0] == 1:# and self.around[2][1][1] == 0:\n", + " next[1] += 1\n", + " reward += 1\n", + " #print('down')\n", + "\n", + " elif action == 4:\n", + "\n", + " if self.around[0][1][0] == 1:# and self.around[1][0][1] == 0:\n", + " next[0] += -1\n", + " reward -= 1\n", + " #print('left')\n", + "\n", + " # IF THE AGENT HAS NOT CHANGED LOCATION THIS ACTION\n", + " if torch.equal(self.location, next):\n", + " reward = -0.5\n", + "\n", + " self.location = torch.clone(next)\n", + " self.around = torch.from_numpy(maze.get_local_maze_information(*self.location))\n", + "\n", + "\n", + " # IF THE AGENT SUCCESSFULLY GETS TO THE GOAL\n", + " done = False\n", + " if torch.equal(self.location, self.end):\n", + " reward = 100\n", + " done = True\n", + "\n", + " # SUBTRACT TIME BASED PENALTY FROM REWARD\n", + " reward -= self.time_step * 0.001\n", + " self.time_step +=1\n", + "\n", + "\n", + "\n", + " state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1)\n", + "\n", + " return state, reward, done\n" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 363, + "outputs": [], + "source": [ + "def print_alt_maze(x,y):\n", + " k = maze.get_local_maze_information(x,y)\n", + " r1=''\n", + " r2=''\n", + " r3=''\n", + " for i in k:\n", + "\n", + " if i[0][0]==1:\n", + " r1 += 'O '\n", + " else:\n", + " r1 += 'X '\n", + " if i[1][0]==1:\n", + " r2 += 'O '\n", + " else:\n", + " r2 += 'X '\n", + " if i[2][0]==1:\n", + " r3 += 'O '\n", + " else:\n", + " r3 += 'X '\n", + "\n", + " print('======')\n", + " print(r1)\n", + " print(r2)\n", + " print(r3)\n", + " print('======')\n", + "env2 = envirnoment(torch.tensor([1,1]), torch.tensor([2,1]) )" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 364, + "outputs": [], + "source": [ + "class ActorCritic(nn.Module):\n", + "\n", + " def __init__(self, in_size, num_of_actions):\n", + " super(ActorCritic, self).__init__()\n", + "\n", + " self.fc0 = nn.Linear(in_size,64)\n", + " self.fc1 = nn.Linear(64, 128)\n", + " self.fc2 = nn.Linear(128, 256)\n", + "\n", + " self.actor = nn.Linear(256, num_of_actions)\n", + " self.critic = nn.Linear(256,1)\n", + "\n", + "\n", + " def forward(self, x):\n", + "\n", + " out = self.fc0(x)\n", + " out = torch.relu(out)\n", + " out = self.fc1(out)\n", + " out = torch.relu(out)\n", + " out = self.fc2(out)\n", + " out = torch.relu(out)\n", + "\n", + " a_probs = self.actor(out)\n", + " a_probs = F.softmax(a_probs,dim=1)\n", + "\n", + " state_vals = self.critic(out)\n", + "\n", + " return a_probs, state_vals" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 365, + "outputs": [], + "source": [ + "\n", + "def train_loop(model, env, episodes=1000, ep_length= 500, target_update_timing = 128):\n", + " model.train()\n", + " steps = 0\n", + "\n", + " PPO_BUFFER = []\n", + "\n", + "\n", + " for ep in range(episodes):\n", + " print(ep)\n", + " ep_reward = 0\n", + " state = env.reset()\n", + "\n", + " for i in range(ep_length):\n", + " steps += 1\n", + "\n", + " \"\"\" Create a categorical softmax distribution from the outputs of the model to sample actions from \"\"\"\n", + " a_porbs, value = model(state.float().to(device))\n", + " distribution = Categorical(a_porbs)\n", + "\n", + " if steps < target_update_timing:\n", + " action = torch.randint(0,5,(1,))\n", + "\n", + " else:\n", + " action = distribution.sample()\n", + "\n", + " #print(a_list[action])\n", + " log_a_prob = distribution.log_prob(action.to(device))\n", + " #print(distribution)\n", + "\n", + " \"\"\"\n", + " STATE\n", + " \"\"\"\n", + " next_state, reward, done = env.take_action(action)\n", + " ep_reward += reward\n", + " PPO_BUFFER.append([state, action, log_a_prob, reward, done])\n", + " state = next_state\n", + "\n", + " if steps !=0 and steps % target_update_timing == 0:\n", + "\n", + " train_agent_PPO(model, PPO_BUFFER, target_update_timing)\n", + " PPO_BUFFER = []\n", + "\n", + " if i == ep_length-1:\n", + " print(env.location)\n", + " print(str(ep_reward/i))\n", + "\n", + " if done:\n", + " print(env.location)\n", + " print(str(ep_reward/i))\n", + " break" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 366, + "outputs": [], + "source": [ + "def train_agent_PPO(model, PPO_BUFFER, batch_size,gamma=0.99, cycles = 1):\n", + "\n", + " model_copy = ActorCritic(state_size,5).to(device)\n", + " model_copy.load_state_dict(model.state_dict())\n", + " optim = torch.optim.Adam(model_copy.parameters(), lr =0.01)\n", + " criterion = nn.MSELoss()\n", + "\n", + "\n", + " # Calculate discounted rewards\n", + " discounted_rewards = []\n", + " discounted_reward = 0\n", + " for experience in reversed(PPO_BUFFER):\n", + " if experience[4] == True:\n", + " discounted_reward = 0\n", + " discounted_reward = experience[3] + gamma * discounted_reward\n", + " discounted_rewards.insert(0, discounted_reward)\n", + "\n", + " discounted_reward_tensor = torch.tensor([discounted_rewards],dtype=torch.float).to(device)\n", + " discounted_reward_tensor = (discounted_reward_tensor - discounted_reward_tensor.mean()) / (discounted_reward_tensor.std() + 1e-10)\n", + " discounted_reward_tensor = discounted_reward_tensor.view(-1,1)\n", + " #print(\"DISCOUNT TENSOR\" + str(discounted_reward_tensor.shape))\n", + "\n", + " state_tensor = torch.cat([t[0] for t in PPO_BUFFER])\n", + " state_tensor = state_tensor.view(-1,21).to(device).float()\n", + " #print(\"STATE TENSOR\" + str(state_tensor.shape))\n", + "\n", + " action_tensor = torch.tensor([t[1] for t in PPO_BUFFER])\n", + " action_tensor = action_tensor.view(-1,1).to(device)\n", + " #print(\"ACTION TENSOR\" + str(action_tensor))\n", + "\n", + " log_tensor = torch.tensor([t[1] for t in PPO_BUFFER])\n", + " log_tensor = log_tensor.view(-1,1).to(device)\n", + " #print(\"ACTION TENSOR\" + str(log_tensor))\n", + "\n", + "\n", + " #torch.autograd.set_detect_anomaly(True)\n", + " for i in range(cycles):\n", + "\n", + " new_policy, values = model_copy(state_tensor.float().to(device))\n", + " new_distribution = Categorical(new_policy)\n", + " new_log = new_distribution.log_prob(action_tensor.view(128))\n", + " new_log = new_log.view(128,1)\n", + "\n", + " advantages = discounted_reward_tensor - values\n", + "\n", + " ratios = torch.exp(new_log - log_tensor)\n", + " #print(discounted_reward_tensor.shape)\n", + "\n", + "\n", + " actor_loss = -torch.min(advantages * ratios, torch.clamp(ratios, 1-0.2, 1+0.2)*advantages)\n", + " critic_loss = criterion(discounted_reward_tensor,values)\n", + "\n", + "\n", + " entropy_loss = new_distribution.entropy() # Add cross-entropy loss to encourage exploration\n", + "\n", + " optim.zero_grad()\n", + " total_loss = actor_loss + critic_loss + 0.01 * entropy_loss\n", + " total_loss = total_loss.mean()\n", + " total_loss.backward()\n", + " optim.step()\n", + "\n", + " #print(\"Example actor loss:\", str(actor_loss.mean()))\n", + " #print(\"Example critic loss:\", str(critic_loss.mean()))\n", + " model.load_state_dict(model_copy.state_dict())" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 367, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "tensor([5, 1])\n", + "tensor([-0.6989])\n", + "1\n", + "tensor([5, 1])\n", + "tensor([-0.7310])\n", + "2\n", + "tensor([5, 2])\n", + "tensor([-0.7099])\n", + "3\n", + "tensor([5, 1])\n", + "tensor([-0.6207])\n", + "4\n", + "tensor([5, 1])\n", + "tensor([-0.7229])\n", + "5\n", + "tensor([5, 4])\n", + "tensor([-0.7099])\n", + "6\n", + "tensor([5, 1])\n", + "tensor([-0.7310])\n", + "7\n", + "tensor([5, 1])\n", + "tensor([-0.7229])\n", + "8\n", + "tensor([5, 1])\n", + "tensor([-0.7310])\n", + "9\n", + "tensor([5, 5])\n", + "tensor([-0.6969])\n", + "10\n", + "tensor([5, 1])\n", + "tensor([-0.7049])\n", + "11\n", + "tensor([5, 2])\n", + "tensor([-0.7039])\n", + "12\n", + "tensor([5, 1])\n", + "tensor([-0.7330])\n", + "13\n", + "tensor([5, 1])\n", + "tensor([-0.7330])\n", + "14\n", + "tensor([5, 1])\n", + "tensor([-0.7310])\n", + "15\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "16\n", + "tensor([5, 2])\n", + "tensor([-0.7360])\n", + "17\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "18\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "19\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "20\n", + "tensor([5, 3])\n", + "tensor([-0.7310])\n", + "21\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "22\n", + "tensor([5, 2])\n", + "tensor([-0.7360])\n", + "23\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "24\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "25\n", + "tensor([5, 4])\n", + "tensor([-0.7280])\n", + "26\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "27\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "28\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "29\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "30\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "31\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "32\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "33\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "34\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "35\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "36\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "37\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "38\n", + "tensor([5, 2])\n", + "tensor([-0.7360])\n", + "39\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "40\n", + "tensor([5, 1])\n", + "tensor([-0.7350])\n", + "41\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "42\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "43\n", + "tensor([5, 2])\n", + "tensor([-0.7340])\n", + "44\n", + "tensor([5, 2])\n", + "tensor([-0.7360])\n", + "45\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "46\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "47\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "48\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "49\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "50\n", + "tensor([5, 2])\n", + "tensor([-0.7360])\n", + "51\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "52\n", + "tensor([5, 3])\n", + "tensor([-0.7310])\n", + "53\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "54\n", + "tensor([5, 3])\n", + "tensor([-0.7310])\n", + "55\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "56\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "57\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "58\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "59\n", + "tensor([5, 1])\n", + "tensor([-0.7370])\n", + "60\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "61\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "62\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "63\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "64\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "65\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "66\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "67\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "68\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "69\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "70\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "71\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "72\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "73\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "74\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "75\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "76\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "77\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "78\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "79\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "80\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "81\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "82\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "83\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "84\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "85\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "86\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "87\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "88\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "89\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "90\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "91\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "92\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "93\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "94\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "95\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "96\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "97\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "98\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "99\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "100\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "101\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "102\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "103\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "104\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "105\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "106\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "107\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "108\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "109\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "110\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "111\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "112\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "113\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "114\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "115\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "116\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "117\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "118\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "119\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "120\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "121\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "122\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "123\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "124\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "125\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "126\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "127\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "128\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "129\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "130\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "131\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "132\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "133\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "134\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "135\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "136\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "137\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "138\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "139\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "140\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "141\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "142\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "143\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "144\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "145\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "146\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "147\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "148\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "149\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "150\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "151\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "152\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "153\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "154\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "155\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "156\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "157\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "158\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "159\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "160\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "161\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "162\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "163\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "164\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "165\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "166\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "167\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "168\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "169\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "170\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "171\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "172\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "173\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "174\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "175\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "176\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "177\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "178\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "179\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "180\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "181\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "182\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "183\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "184\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "185\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "186\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "187\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "188\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "189\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "190\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "191\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "192\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "193\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "194\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "195\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "196\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "197\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "198\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "199\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "200\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "201\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "202\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "203\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "204\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "205\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "206\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "207\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "208\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "209\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "210\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "211\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "212\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "213\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "214\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "215\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "216\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "217\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "218\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "219\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "220\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "221\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "222\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "223\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "224\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "225\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "226\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "227\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "228\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "229\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "230\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "231\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "232\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "233\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "234\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "235\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "236\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "237\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "238\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "239\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "240\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "241\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "242\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "243\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "244\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "245\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "246\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "247\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "248\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "249\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "250\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "251\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "252\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "253\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "254\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "255\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "256\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "257\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "258\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "259\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "260\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "261\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "262\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "263\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "264\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "265\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "266\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "267\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "268\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "269\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "270\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "271\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "272\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "273\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "274\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "275\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "276\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "277\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "278\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "279\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "280\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "281\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "282\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "283\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "284\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "285\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "286\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "287\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "288\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "289\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "290\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "291\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "292\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "293\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "294\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "295\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "296\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "297\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "298\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "299\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "300\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "301\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "302\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "303\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "304\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "305\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "306\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "307\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "308\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "309\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "310\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "311\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "312\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "313\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "314\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "315\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "316\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "317\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "318\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "319\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "320\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "321\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "322\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "323\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "324\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "325\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "326\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "327\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "328\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "329\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "330\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "331\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "332\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "333\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "334\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "335\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "336\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "337\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "338\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "339\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "340\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "341\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "342\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "343\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "344\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "345\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "346\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "347\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "348\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "349\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "350\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "351\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "352\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "353\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "354\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "355\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "356\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "357\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "358\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "359\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "360\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "361\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "362\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "363\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "364\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "365\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "366\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "367\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "368\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "369\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "370\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "371\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "372\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "373\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "374\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "375\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "376\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "377\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "378\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "379\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "380\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "381\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "382\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "383\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "384\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "385\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "386\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "387\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "388\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "389\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "390\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "391\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "392\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "393\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "394\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "395\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "396\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "397\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "398\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "399\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "400\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "401\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "402\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "403\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "404\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "405\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "406\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "407\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "408\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "409\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "410\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "411\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "412\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "413\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "414\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "415\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "416\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "417\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "418\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "419\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "420\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "421\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "422\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "423\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "424\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "425\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "426\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "427\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "428\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "429\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "430\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "431\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "432\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "433\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "434\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "435\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "436\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "437\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "438\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "439\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "440\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "441\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "442\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "443\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "444\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "445\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "446\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "447\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "448\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "449\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "450\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "451\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "452\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "453\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "454\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "455\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "456\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "457\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "458\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "459\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "460\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "461\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "462\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "463\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "464\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "465\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "466\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "467\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "468\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "469\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "470\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "471\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "472\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "473\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "474\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "475\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "476\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "477\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "478\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "479\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "480\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "481\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "482\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "483\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "484\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "485\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "486\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "487\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "488\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "489\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "490\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "491\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "492\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "493\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "494\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "495\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "496\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "497\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "498\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "499\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "500\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "501\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "502\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "503\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "504\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "505\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "506\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "507\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "508\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "509\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "510\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "511\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "512\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "513\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "514\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "515\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "516\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "517\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "518\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "519\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "520\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "521\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "522\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "523\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "524\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "525\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "526\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "527\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "528\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "529\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "530\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "531\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "532\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "533\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "534\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "535\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "536\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "537\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "538\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "539\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "540\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "541\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "542\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "543\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "544\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "545\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "546\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "547\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "548\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "549\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "550\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "551\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "552\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "553\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "554\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "555\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "556\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "557\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "558\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "559\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "560\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "561\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "562\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "563\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "564\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "565\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "566\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "567\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "568\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "569\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "570\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "571\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "572\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "573\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "574\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "575\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "576\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "577\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "578\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "579\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "580\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "581\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "582\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "583\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "584\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "585\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "586\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "587\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "588\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "589\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "590\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "591\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "592\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "593\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "594\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "595\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "596\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "597\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "598\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "599\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "600\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "601\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "602\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "603\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "604\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "605\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "606\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "607\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "608\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "609\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "610\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "611\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "612\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "613\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "614\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "615\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "616\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "617\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "618\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "619\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "620\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "621\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "622\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "623\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "624\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "625\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "626\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "627\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "628\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "629\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "630\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "631\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "632\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "633\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "634\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "635\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "636\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "637\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "638\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "639\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "640\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "641\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "642\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "643\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "644\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "645\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "646\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "647\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "648\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "649\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "650\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "651\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "652\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "653\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "654\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "655\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "656\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "657\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "658\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "659\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "660\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "661\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "662\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "663\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "664\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "665\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "666\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "667\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "668\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "669\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "670\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "671\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "672\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "673\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "674\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "675\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "676\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "677\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "678\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "679\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "680\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "681\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "682\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "683\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "684\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "685\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "686\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "687\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "688\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "689\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "690\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "691\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "692\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "693\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "694\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "695\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "696\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "697\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "698\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "699\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "700\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "701\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "702\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "703\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "704\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "705\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "706\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "707\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "708\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "709\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "710\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "711\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "712\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "713\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "714\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "715\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "716\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "717\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "718\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "719\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "720\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "721\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "722\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "723\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "724\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "725\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "726\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "727\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "728\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "729\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "730\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "731\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "732\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "733\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "734\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "735\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "736\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "737\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "738\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "739\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "740\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "741\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "742\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "743\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "744\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "745\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "746\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "747\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "748\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "749\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "750\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "751\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "752\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "753\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "754\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "755\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "756\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "757\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "758\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "759\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "760\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "761\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "762\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "763\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "764\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "765\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "766\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "767\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "768\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "769\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "770\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "771\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "772\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "773\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "774\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "775\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "776\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "777\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "778\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "779\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "780\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "781\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "782\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "783\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "784\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "785\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "786\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "787\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "788\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "789\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "790\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "791\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "792\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "793\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "794\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "795\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "796\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "797\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "798\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "799\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "800\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "801\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "802\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "803\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "804\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "805\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "806\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "807\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "808\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "809\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "810\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "811\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "812\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "813\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "814\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "815\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "816\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "817\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "818\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "819\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "820\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "821\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "822\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "823\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "824\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "825\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "826\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "827\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "828\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "829\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "830\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "831\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "832\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "833\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "834\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "835\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "836\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "837\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "838\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "839\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "840\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "841\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "842\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "843\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "844\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "845\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "846\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "847\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "848\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "849\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "850\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "851\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "852\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "853\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "854\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "855\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "856\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "857\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "858\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "859\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "860\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "861\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "862\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "863\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "864\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "865\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "866\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "867\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "868\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "869\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "870\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "871\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "872\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "873\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "874\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "875\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "876\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "877\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "878\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "879\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "880\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "881\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "882\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "883\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "884\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "885\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "886\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "887\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "888\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "889\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "890\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "891\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "892\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "893\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "894\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "895\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "896\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "897\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "898\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "899\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "900\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "901\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "902\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "903\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "904\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "905\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "906\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "907\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "908\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "909\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "910\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "911\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "912\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "913\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "914\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "915\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "916\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "917\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "918\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "919\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "920\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "921\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "922\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "923\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "924\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "925\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "926\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "927\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "928\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "929\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "930\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "931\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "932\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "933\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "934\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "935\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "936\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "937\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "938\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "939\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "940\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "941\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "942\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "943\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "944\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "945\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "946\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "947\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "948\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "949\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "950\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "951\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "952\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "953\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "954\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "955\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "956\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "957\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "958\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "959\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "960\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "961\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "962\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "963\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "964\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "965\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "966\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "967\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "968\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "969\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "970\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "971\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "972\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "973\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "974\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "975\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "976\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "977\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "978\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "979\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "980\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "981\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "982\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "983\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "984\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "985\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "986\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "987\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "988\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "989\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "990\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "991\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "992\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "993\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "994\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "995\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "996\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "997\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "998\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n", + "999\n", + "tensor([5, 1])\n", + "tensor([-0.7390])\n" + ] + } + ], + "source": [ + "env = envirnoment(torch.tensor([1,1]), torch.tensor([199, 199]) )\n", + "state_size = 21\n", + "\n", + "gamma = 0.99\n", + "agent = ActorCritic(state_size, 5).to(device)\n", + "optim = torch.optim.Adam(agent.parameters(), lr =0.01)\n", + "\n", + "train_loop(agent,env)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/ActorCriticSolution.py b/ActorCriticSolution.py new file mode 100644 index 0000000000000000000000000000000000000000..b48fa80a969e0ae47217f6c75ecff9fed8b1a4b0 --- /dev/null +++ b/ActorCriticSolution.py @@ -0,0 +1,326 @@ +import torch +import read_maze as maze +from enum import Enum +import torch.optim +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import random +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +maze.load_maze() + +""" +State is defined as the the x-y coordinate of the agent concatenated with its surrounding information(flattened) and +timestep + +Actions are [Nothing = 0, Up = 1, Right = 2, Down = 3, Left = 4] +An action that reduces the straight line distance to to the goal(Down and Right) gives a reward of +1, actions that +increase it give a reward of -1. If an action doesnt change the agents location, the reward is -0.5. A time penalty of +-0.001 is applied for each timestep. + +""" + +class envirnoment(): + + def __init__(self,start, end): + self.start = torch.clone(start) + self.location = torch.clone(start) + self.end = torch.clone(end) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + self.action_space = 5 + self.time_step = torch.tensor([0]) + + + def reset(self): + self.location = torch.clone(self.start) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + self.time_step = torch.tensor([0]) + + state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1) + return state, self.location, self.around + + + def take_action(self, action): + # NEXT IS WHERE THE AGENT WILL LAND AFTER PERFORMING AN ACTION + next = torch.clone(self.location) + reward = 0 + + if action == 1: + if self.around[1][0][0] == 1:# and self.around[0][1][1] == 0: + next[1] += -1 + reward -= 1 + #print('up') + + elif action == 2: + + if self.around[2][1][0] == 1:# and self.around[1][2][1] == 0: + next[0] += 1 + reward += 1 + #print('right') + + elif action == 3: + if self.around[1][2][0] == 1:# and self.around[2][1][1] == 0: + next[1] += 1 + reward += 1 + #print('down') + + elif action == 4: + + if self.around[0][1][0] == 1:# and self.around[1][0][1] == 0: + next[0] += -1 + reward -= 1 + #print('left') + + # IF THE AGENT HAS NOT CHANGED LOCATION THIS ACTION + if torch.equal(self.location, next): + reward = -0.5 + + self.location = torch.clone(next) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + + + # IF THE AGENT SUCCESSFULLY GETS TO THE GOAL + done = False + if torch.equal(self.location, self.end): + reward = 10 + done = True + + # SUBTRACT TIME BASED PENALTY FROM REWARD + reward -= self.time_step * 0.01 + self.time_step +=1 + + + + state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1) + + return state, reward, done, self.location, self.around + + +""" +A function to display the local maze information +""" +def print_maze(x,y): + k = maze.get_local_maze_information(x,y) + r1='' + r2='' + r3='' + for i in k: + + if i[0][0]==1: + r1 += 'O ' + else: + r1 += 'X ' + if i[1][0]==1: + r2 += 'O ' + else: + r2 += 'X ' + if i[2][0]==1: + r3 += 'O ' + else: + r3 += 'X ' + + print('======') + print(r1) + print(r2) + print(r3) + print('======') + + +"""An actor-critic model with a GRU, the actor and critic are different heads of the same model""" + +class ActorCritic(nn.Module): + + def __init__(self, in_size, num_of_actions, gru_hidden_size = 64): + super(ActorCritic, self).__init__() + self.gru_hidden_size = gru_hidden_size + self.gru = nn.GRU(in_size,64,1,batch_first=False,bidirectional=False) + + self.fc1 = nn.Linear(64, 128) + self.actor = nn.Linear(128, num_of_actions) + self.critic = nn.Linear(128,1) + + + def forward(self, x, h): + + out, h = self.gru(x, h) + if out.dim()==3: + out = out.squeeze(0) + + out = torch.relu(out) + out = self.fc1(out) + out = torch.relu(out) + + a_probs = self.actor(out) + a_probs = F.softmax(a_probs,dim=-1) + + state_vals = self.critic(out) + + return a_probs, state_vals, h + + def init_hidden(self): + return torch.rand(1, self.gru_hidden_size).cuda() + + +""" """ +env = envirnoment(torch.tensor([1,1]), torch.tensor([199, 199]) ) +state_size = 21 + +gamma = 0.99 +agent = ActorCritic(state_size, 5).to(device) +optim = torch.optim.Adam(agent.parameters(), lr =0.01) +""" """ + +def train_agent_PPO(model, PPO_BUFFER, gamma, cycles = 5): + + # We create a copy of our model to do the training cycles with, at the end of this cycle we paste + # the weights from the copy to the original network + model_copy = ActorCritic(state_size,5).to(device) + model_copy.load_state_dict(model.state_dict()) + optim = torch.optim.Adam(model_copy.parameters(), lr =0.01) + criterion = nn.HuberLoss() + + # Calculate discounted rewards, we can get an estimate of a states future rewards by working backwards + # through the trajectory + discounted_rewards = [] + discounted_reward = 0 + for experience in reversed(PPO_BUFFER): + if experience[4] == True: # experience[4] is the done flags + discounted_reward = 0 + discounted_reward = experience[2] + gamma * discounted_reward # experience[2] is the states + discounted_rewards.insert(0, discounted_reward) + + """ + Creating a batch and reshaping + """ + discounted_reward_tensor = torch.tensor([discounted_rewards],dtype=torch.float).view(64,1).to(device) + #print("DISCOUNT TENSOR" + str(discounted_reward_tensor.shape)) + + state_tensor = torch.cat([t[0] for t in PPO_BUFFER]) + state_tensor = state_tensor.view(-1,21).to(device).float().unsqueeze(0) + #print("STATE TENSOR" + str(state_tensor.shape)) + + action_tensor = torch.tensor([t[1] for t in PPO_BUFFER]) + action_tensor = action_tensor.view(-1,1).to(device) + #print("ACTION TENSOR" + str(action_tensor.shape)) + + next_state = torch.cat([t[3] for t in PPO_BUFFER]) + next_state = next_state.view(-1,21).to(device).float() + #print("NEXT STATE TENSOR" + str(next_state.shape)) + + + hidden_state = torch.cat([t[5] for t in PPO_BUFFER]).unsqueeze(0).to(device) + #print(hidden_state.shape) + #hidden_state = hidden_state.view(2,-1,64).to(device).float() + #print("HIDDEN TENSOR" + str(hidden_state.shape)) + + + #torch.autograd.set_detect_anomaly(True) + + """ + PPO lets us do multiple updates on the same training data + """ + for i in range(cycles): + + current_policy, _, _ = model(state_tensor.float().to(device),hidden_state) + + new_policy, values, _ = model_copy(state_tensor.float().to(device),hidden_state) + advantages = discounted_reward_tensor - values + + safety = 0.0001 + + # Find the ratio of -the probability of choosing an action given a state- + # between the original policy and the model_copys policy, this is 1 in the first cycle + # since they are the same network + ratios = (torch.gather(new_policy, 1,action_tensor)+ safety)/ (torch.gather(current_policy, 1,action_tensor)+safety) + + # Clippled actor loss + actor_loss = -1 * torch.min(advantages * ratios, torch.clamp(ratios, 1-0.2, 1+0.2)*advantages) + # Critic loss calculated as Huberloss between the current state values and the "future rewards" + critic_loss = criterion(discounted_reward_tensor,values.squeeze(0)) + + optim.zero_grad() + total_loss = actor_loss + critic_loss + total_loss = total_loss.mean() + total_loss.backward() + + optim.step() + #print("Example actor loss:", str(actor_loss.mean())) + #print("Example critic loss:", str(critic_loss.mean())) + + model.load_state_dict(model_copy.state_dict()) + + +def train_loop(model, env, episodes=1000, ep_length= 100, target_update_timing = 64, epsilon_start= 0.9, epsilon_last = 0.2, epsilon_step = 0.0001): + + model.train() + steps = 0 + + """ + We use a PPO buffer to save a trajectory of length "target update timing" we are doing on-policy training, so we + empty the PPO buffer after a training cycle completes. + """ + + PPO_BUFFER = [] + + """ We use an epsilon greedy algorithm """ + epsilon = epsilon_start + + for ep in range(episodes): + print(ep) + print("Epsilon: " +str(epsilon)) + + ep_reward = 0 + + if(ep == 10): + ep_length=600 + + state,_,_ = env.reset() + + h = model.init_hidden() # hidden state for the GRU + + + + for i in range(ep_length): + #print(env.location) + steps += 1 + epsilon = max(epsilon_last, epsilon_start - 0.0003 *steps) + rr = torch.rand(1) + action = 0 + + if rr < epsilon: # pick a random action, we still use the model to get the next hidden state + _, values,next_h = model(state.float().to(device),h) + action = torch.randint(0,5,(1,)) + + else: # Pick the action according to our policy + #print(main_net(state)) + a_porbs, value,next_h = model(state.float().to(device),h) + action = torch.argmax(a_porbs) + #print(action) + + + next_state, reward, done, _, _ = env.take_action(action) # step through the environment + + ep_reward += reward + + PPO_BUFFER.append([state, action, reward, next_state, done, h.detach()]) + + state = next_state + h = next_h + + if steps !=0 and steps % target_update_timing == 0: # Train the model + train_agent_PPO(model,PPO_BUFFER,gamma) + PPO_BUFFER = [] + + if i == ep_length-1: + print(env.location) + print(str(ep_reward/i)) + + if done: + print(env.location) + print(str(ep_reward/i)) + break + print(state) + +train_loop(agent,env) \ No newline at end of file diff --git a/DEEPQsolution.py b/DEEPQsolution.py new file mode 100644 index 0000000000000000000000000000000000000000..63b755d9f8dcd474053544f44fb99537ecb06e13 --- /dev/null +++ b/DEEPQsolution.py @@ -0,0 +1,268 @@ +import torch +import read_maze as maze +from enum import Enum +import torch.optim +import torch.nn as nn +import numpy as np +import random +import gym +from torch.nn import functional as F +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +maze.load_maze() + +""" +State is defined as the the x-y coordinate of the agent concatenated with its surrounding information(flattened) and +timestep + +Actions are [Nothing = 0, Up = 1, Right = 2, Down = 3, Left = 4] +An action that reduces the straight line distance to to the goal(Down and Right) gives a reward of +1, actions that +increase it give a reward of -1. If an action doesnt change the agents location, the reward is -0.5. A time penalty of +-0.001 is applied for each timestep. + +""" +class envirnoment(): + + def __init__(self,start, end): + self.start = torch.clone(start) + self.location = torch.clone(start) + self.end = torch.clone(end) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + self.action_space = 5 + self.time_step = torch.tensor([0]) + + + def reset(self): + self.location = torch.clone(self.start) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + self.time_step = torch.tensor([0]) + + state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1) + return state + + + def take_action(self, action): + # NEXT IS WHERE THE AGENT WILL LAND AFTER PERFORMING AN ACTION + next = torch.clone(self.location) + reward = 0 + + if action == 1: + if self.around[1][0][0] == 1:# and self.around[0][1][1] == 0: + next[1] += -1 + reward -= 1 + #print('up') + + elif action == 2: + + if self.around[2][1][0] == 1:# and self.around[1][2][1] == 0: + next[0] += 1 + reward += 1 + #print('right') + + elif action == 3: + if self.around[1][2][0] == 1:# and self.around[2][1][1] == 0: + next[1] += 1 + reward += 1 + #print('down') + + elif action == 4: + + if self.around[0][1][0] == 1:# and self.around[1][0][1] == 0: + next[0] += -1 + reward -= 1 + #print('left') + + # IF THE AGENT HAS NOT CHANGED LOCATION THIS ACTION + if torch.equal(self.location, next): + reward = -0.5 + + self.location = torch.clone(next) + self.around = torch.from_numpy(maze.get_local_maze_information(*self.location)) + + # IF THE AGENT SUCCESSFULLY GETS TO THE GOAL + done = False + if torch.equal(self.location, self.end): + reward = 100 + done = True + + # SUBTRACT TIME BASED PENALTY FROM REWARD + reward -= self.time_step * 0.001 + self.time_step +=1 + + state = torch.cat([self.location.unsqueeze(0), self.around.view(-1,18), self.time_step.unsqueeze(0)],dim=1) + return state, reward, done + + +""" +A replay buffer is used to keep the last few [state action next_state reward] steps, after the buffer fills up to max +capacity for the first time we sample from the buffer for training data each timestep. + +""" + +class ReplayBuffer(): + def __init__(self,batch_size = 16, replay_capacity =100): + self.memory = [] + self.batch_size = batch_size + self.replay_capacity = replay_capacity + + def __len__(self): + return len(self.memory) + + def save(self, experience): # experience: [state, action, new_state, reward] + if len(self.memory)> self.replay_capacity: + self.memory.pop(0) + self.memory.append(experience) + + def sample(self): + if len(self.memory)< self.batch_size: + return random.sample(self.memory, len(self.memory)) + return random.sample(self.memory, self.batch_size) + + + +""" +A simple feedforward network +""" +class DeepQ(nn.Module): + + def __init__(self, in_size=21, out_size=5): + super(DeepQ, self).__init__() + self.fc1 = nn.Linear(21, 32) + self.fc2 = nn.Linear(32,64) + self.fc3 = nn.Linear(64,out_size) + + def forward(self, x): + + out = self.fc1(x) + out = torch.relu(out) + out = self.fc2(out) + out = torch.relu(out) + out = self.fc3(out) + + return out + + +main_net = DeepQ().to(device) +target_net = DeepQ().to(device) +target_net.load_state_dict(main_net.state_dict()) + +optim = torch.optim.Adam(main_net.parameters(), lr =0.001) +replay_buffer = ReplayBuffer() +criterion = nn.HuberLoss() +gamma = 0.99 +state_size = 21 + + +def train_network(): + + if len(replay_buffer) < replay_buffer.replay_capacity-1: + return + + exp_batch = replay_buffer.sample() + """ + Creating a batch from samples and rearranging sizes + """ + + state_tensor = torch.cat([t[0] for t in exp_batch]) + state_tensor = state_tensor.view(-1,state_size).to(device).float() + + action_tensor = torch.tensor([t[1] for t in exp_batch]) + action_tensor = action_tensor.view(-1,1).to(device) + + next_state = torch.cat([t[2] for t in exp_batch]) + next_state = next_state.view(-1,state_size).to(device).float() + + reward_tensor = torch.tensor([t[3] for t in exp_batch]) + reward_tensor = reward_tensor.view(-1,1).to(device) + + state_values = main_net(state_tensor) + state_values = state_values.squeeze(1) + + """ + Querying the target network for to estimate the future rewards + """ + + future_values = target_net(next_state) + future_values = future_values.squeeze(1) + + max_future_values,_ = torch.max(future_values,dim=1) + max_future_values = max_future_values.unsqueeze(1) + + """ + new_value is what the current value will be updated towards + """ + new_value = reward_tensor + gamma * max_future_values + + current_value = torch.gather(state_values,1,action_tensor) + loss = criterion(current_value,new_value) + + optim.zero_grad() + loss.backward() + optim.step() + + +""" +We follow an epsilon-greedy algorithm, epsilon decreases from 0.9 down to 0.05 linearly each timestep. +""" +epsilon_start = 0.9 +epsilon_last = 0.05 + + +""" +target network updates its weights to be equal tom the main netwrok every 64 timesteps +""" +target_update_timing = 64 + +episodes = 10000 + +env = envirnoment(torch.tensor([1,1]), torch.tensor([199, 199]) ) + +epsilon = epsilon_start +ep_length = 500 +steps = 0 +for ep in range(episodes): + state = env.reset() + print(ep) + ep_reward = 0 + print(epsilon) + + + for i in range(ep_length): + steps += 1 + + epsilon = max(epsilon_last, epsilon_start - 0.0005 *steps) + + rr = torch.rand(1) + + if rr < epsilon: # choose a random action + action = torch.randint(0,4,(1,)) + + else: # follow the current policy for the action + with torch.no_grad(): + action_scores = main_net(state.float().to(device)).detach() + action = torch.argmax(action_scores) + + + next_state, reward, done = env.take_action(action) + + + ep_reward += reward + replay_buffer.save( [state, action, next_state.detach(),reward] ) #save the transition to the buffer + state = torch.clone(next_state).detach() + + train_network() # train main network + + if steps % target_update_timing == 0: + target_net.load_state_dict(main_net.state_dict()) + + #print(reward) + if done: + print("SUCCESSS"+ str(state)) + #print(str(ep_reward/i)) + break + + if i == ep_length-1: + print(env.location) + print(str(ep_reward/i)) +