Я запускаю этот код ( https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb ) в пользовательском наборе данных, но я сталкиваюсь с этой ошибкой.RuntimeError: one of the variables needed for gradient computation has been modified by an in place operation: [torch. cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Пожалуйста, обратитесь к приведенной выше ссылке на код, чтобы уточнить, где возникает ошибка.
Я запускаю эту модель в пользовательском наборе данных, часть загрузчика данных вставлена ниже.
import torchvision.transforms as transforms
train_transform = transforms.Compose([
transforms.Resize((256,256)),
#transforms.RandomResizedCrop(256),
#transforms.RandomHorizontalFlip(),
#transforms.ColorJitter(),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
])
class Flare(Dataset):
def __init__(self, flare_dir, wf_dir,transform = None):
self.flare_dir = flare_dir
self.wf_dir = wf_dir
self.transform = transform
self.flare_img = os.listdir(flare_dir)
self.wf_img = os.listdir(wf_dir)
def __len__(self):
return len(self.flare_img)
def __getitem__(self, idx):
f_img = Image.open(os.path.join(self.flare_dir, self.flare_img[idx])).convert("RGB")
for i in self.wf_img:
if (self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]):
wf_img = Image.open(os.path.join(self.wf_dir, i)).convert("RGB")
break
f_img = self.transform(f_img)
wf_img = self.transform(wf_img)
return f_img, wf_img
flare_dir = '../input/flaredataset/Flare/Flare_img'
wf_dir = '../input/flaredataset/Flare/Without_Flare_'
flare_img = os.listdir(flare_dir)
wf_img = os.listdir(wf_dir)
wf_img.sort()
flare_img.sort()
print(wf_img[0])
train_ds = Flare(flare_dir, wf_dir,train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_ds,
batch_size=BATCH_SIZE,
shuffle=True)
Чтобы лучше понять класс набора данных, вы можете сравнить мой класс набора данных со ссылкой, вставленной выше.
Решение проблемы
Ваш код застрял в так называемом «обратном распространении» вашей сети GAN.
То, что вы определили, должно следовать вашему обратному графику:
def backward(self, unet_loss, dis_loss):
dis_loss.backward(retain_graph = True)
self.dis_optimizer.step()
unet_loss.backward()
self.unet_optimizer.step()
Таким образом, в вашем обратном графике вы dis_loss
сначала распространяете то, что представляет собой комбинацию дискриминатора и состязательной потери, а затем вы распространяете то, unet_loss
что является комбинацией UNet
, SSIM
но связано с выходными потерями дискриминатора ContentLoss
. unet_loss
Таким образом, pytorch сбит с толку и выдает эту ошибку, когда вы выполняете шаг оптимизатора, dis_loss
прежде чем даже сохранять обратный график unet_loss
, и я бы рекомендовал вам изменить код следующим образом:
def backward(self, unet_loss, dis_loss):
dis_loss.backward(retain_graph = True)
unet_loss.backward()
self.dis_optimizer.step()
self.unet_optimizer.step()
И с этого начнется ваше обучение! но вы можете поэкспериментировать со своим retain_graph=True
.
И отличная работа над BPPNet Work.
Комментариев нет:
Отправить комментарий