GOOGLE ADS

воскресенье, 10 апреля 2022 г.

Я столкнулся с ошибкой вычисления градиента на месте

Я запускаю этот код ( https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb ) в пользовательском наборе данных, но я сталкиваюсь с этой ошибкой.
RuntimeError: one of the variables needed for gradient computation has been modified by an in place operation: [torch. cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Сообщение об ошибке

Пожалуйста, обратитесь к приведенной выше ссылке на код, чтобы уточнить, где возникает ошибка.

Я запускаю эту модель в пользовательском наборе данных, часть загрузчика данных вставлена ​​ниже.

 import torchvision.transforms as transforms
train_transform = transforms.Compose([
transforms.Resize((256,256)),
#transforms.RandomResizedCrop(256),
#transforms.RandomHorizontalFlip(),
#transforms.ColorJitter(),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
])
class Flare(Dataset):
def __init__(self, flare_dir, wf_dir,transform = None):
self.flare_dir = flare_dir
self.wf_dir = wf_dir
self.transform = transform
self.flare_img = os.listdir(flare_dir)
self.wf_img = os.listdir(wf_dir)

def __len__(self):
return len(self.flare_img)
def __getitem__(self, idx):
f_img = Image.open(os.path.join(self.flare_dir, self.flare_img[idx])).convert("RGB")
for i in self.wf_img:
if (self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]):
wf_img = Image.open(os.path.join(self.wf_dir, i)).convert("RGB")
break
f_img = self.transform(f_img)
wf_img = self.transform(wf_img)

return f_img, wf_img
flare_dir = '../input/flaredataset/Flare/Flare_img'
wf_dir = '../input/flaredataset/Flare/Without_Flare_'
flare_img = os.listdir(flare_dir)
wf_img = os.listdir(wf_dir)
wf_img.sort()
flare_img.sort()
print(wf_img[0])
train_ds = Flare(flare_dir, wf_dir,train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_ds,
batch_size=BATCH_SIZE,
shuffle=True)

Чтобы лучше понять класс набора данных, вы можете сравнить мой класс набора данных со ссылкой, вставленной выше.


Решение проблемы

Ваш код застрял в так называемом «обратном распространении» вашей сети GAN.

То, что вы определили, должно следовать вашему обратному графику:

def backward(self, unet_loss, dis_loss):
dis_loss.backward(retain_graph = True)
self.dis_optimizer.step()
unet_loss.backward()
self.unet_optimizer.step()

Таким образом, в вашем обратном графике вы dis_lossсначала распространяете то, что представляет собой комбинацию дискриминатора и состязательной потери, а затем вы распространяете то, unet_lossчто является комбинацией UNet, SSIMно связано с выходными потерями дискриминатора ContentLoss. unet_lossТаким образом, pytorch сбит с толку и выдает эту ошибку, когда вы выполняете шаг оптимизатора, dis_lossпрежде чем даже сохранять обратный график unet_loss, и я бы рекомендовал вам изменить код следующим образом:

def backward(self, unet_loss, dis_loss):
dis_loss.backward(retain_graph = True)
unet_loss.backward()
self.dis_optimizer.step()
self.unet_optimizer.step()

И с этого начнется ваше обучение! но вы можете поэкспериментировать со своим retain_graph=True.

И отличная работа над BPPNet Work.

Комментариев нет:

Отправить комментарий

Laravel Datatable addColumn returns ID of one record only

Я пытаюсь использовать Yajra Datatable для интеграции DataTable на свой веб-сайт. Я смог отобразить таблицу, но столкнулся с проблемой. В по...