tensor操作

写PR作业的时候出了诡异bug,train_acc很正常,valid_acc一直是个很小的数字。检查发现问题出在:

1
acc=sum(pred==torch.chunk(torch.nonzero(My_valid),2,1)[1].squeeze())

其中My_valid为one-hot编码,torch.chunk(torch.nonzero(My_valid),2,1)[1].squeeze()操作即one-hot转数字,nonzero()转为 [序号,标签] 的二元tensor,使用chunk()squeeze()转为一维tensor

sum里面的内容为tensor,直接求sum出错(很奇怪求train_acc也是这么写的就没错,小样本没啥问题,大样本结果会变小???)

可以改为:

1
acc=tensor.sum(pred==torch.chunk(torch.nonzero(My_valid),2,1)[1].squeeze())

Or

1
acc=sum(pred==torch.chunk(torch.nonzero(My_valid),2,1)[1].squeeze().cpu().numpy())