# dnc-with-demon/Dataset/Bitmap/BitmapTask.py

# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import torch
import torch.nn.functional as F

from Utils import Visdom
from Utils import universal as U
from Visualize.BitmapTask import visualize_bitmap_task


class BitmapTask(torch.utils.data.Dataset):
    def __init__(self):
        super(BitmapTask, self).__init__()
        self._img = Visdom.Image("preview")

    def set_dump_dir(self, dir):
        self._img.set_dump_dir(dir)

    def __len__(self):
        # Samples are generated on the fly, so report an effectively
        # unbounded dataset length.
        return 0x7FFFFFFF
    def visualize_preview(self, data, net_output):
        # Render the input bitmap next to the reference output and the
        # network's sigmoid-activated prediction.
        img = visualize_bitmap_task(
            data["input"], [data["output"], U.sigmoid(net_output)]
        )
        self._img.draw(img)
    def loss(self, net_output, target):
        # Summed binary cross-entropy on the logits, normalized by batch size.
        return F.binary_cross_entropy_with_logits(
            net_output, target, reduction="sum"
        ) / net_output.size(0)

    def accuracy(self, net_output, target):
        # Note: the reported "accuracy" is the same per-batch BCE as loss(),
        # so lower values are better.
        return F.binary_cross_entropy_with_logits(
            net_output, target, reduction="sum"
        ) / net_output.size(0)
    def demon_loss(self, net_output, target, saved_actions, device):
        """
        Computes the REINFORCE-style loss for the demon.

        :param net_output: network logits of shape [batch, time, bits]
        :param target: target bits, same shape as net_output
        :param saved_actions: per-time-step action records; only .log_prob is used
        :param device: device on which the discount vectors are built
        :return: per-sample demon loss of shape [batch]
        """
        net_output = net_output.detach()
        # Per-time-step loss, summed over the bit dimension: shape [batch, time].
        loss = F.binary_cross_entropy_with_logits(
            net_output, target, reduction="none"
        ).sum(dim=-1)
        policy_losses = []  # actor (policy) loss terms, one per time step
        discount_factor = 0.99
        for i in range(loss.size(1)):
            # Discounted weights [1, g, g^2, ...] for the losses from step i on.
            discount_vector = torch.from_numpy(
                np.power(discount_factor, np.arange(loss.size(1) - i))
            ).to(device=device, dtype=loss.dtype)
            # Policy-gradient term: log-probability of the action taken at
            # step i, weighted by the discounted mean of the remaining losses.
            policy_losses.append(
                saved_actions[i].log_prob.squeeze(1)
                * (discount_vector * loss[:, i:]).mean(dim=1)
            )
        return torch.stack(policy_losses).mean(dim=0) / loss.size(1)
    def state_dict(self):
        return {}

    def load_state_dict(self, state):
        pass
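

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal smoke test for loss() and demon_loss(), assuming net_output and
# target are [batch, time, bits] tensors, each saved action exposes a
# .log_prob of shape [batch, 1], and Utils.Visdom.Image can be constructed
# offline. The SavedAction namedtuple below is a hypothetical stand-in for
# whatever record the training loop actually stores.
if __name__ == "__main__":
    from collections import namedtuple

    SavedAction = namedtuple("SavedAction", ["log_prob"])

    batch, time, bits = 4, 8, 8
    task = BitmapTask()
    net_output = torch.randn(batch, time, bits)
    target = torch.randint(0, 2, (batch, time, bits)).float()

    print("loss:", task.loss(net_output, target).item())

    # One saved action per time step; log_prob is [batch, 1] so that
    # squeeze(1) inside demon_loss yields a [batch] vector.
    saved_actions = [SavedAction(log_prob=torch.randn(batch, 1)) for _ in range(time)]
    print("demon loss:", task.demon_loss(net_output, target, saved_actions,
                                         torch.device("cpu")))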