from fastai.vision.all import *
# Render any images drawn later with a greyscale colormap.
matplotlib.rc('image', cmap='Greys')

# torch.where(cond, x, y) takes x where cond holds and y elsewhere. Here that
# means "1 - prediction" wherever the target equals 1, and "prediction"
# otherwise. Each case probes how the target tensor broadcasts against prds.
# NOTE(review): the bare tuple expressions only display something in a
# notebook; in a plain script their values are discarded.

# Case 1: a (3,) target vector broadcast across both rows of a (2, 3) tensor.
trgts = tensor([0, 1, 0])
prds = tensor([[0.8, 0.4, 0.2], [0.6, 0.2, 0.5]])
(trgts, prds)
torch.mean(torch.where(trgts.eq(1), 1 - prds, prds))

# Case 2: a single 0 target broadcast over a (2, 2) tensor, so prds is kept as-is.
trgts = tensor([0])
prds = tensor([[0.4, 0.3], [0.1, 0.2]])
(trgts, prds)
torch.mean(torch.where(trgts.eq(1), 1 - prds, prds))

# Case 3: a single 1 target broadcast over the same tensor, so every entry becomes 1 - prds.
trgts = tensor([1])
prds = tensor([[0.4, 0.3], [0.1, 0.2]])
(trgts, prds)
torch.mean(torch.where(trgts.eq(1), 1 - prds, prds))
def ret_where(a, b):
    """Select ``1 - b`` where ``a == 1`` and ``b`` elsewhere; return it and its mean.

    The selection uses ``torch.where``, so ``a`` and ``b`` are broadcast
    together. For example, a ``(rows, 1)`` target column applies one target to
    every entry of the matching row of ``b``.

    Args:
        a: Target tensor. Entries equal to 1 select ``1 - b``; all other
            entries select ``b``.
        b: Prediction tensor. It must be floating-point, because
            ``Tensor.mean`` rejects integer tensors. (Presumably the values are
            probabilities in [0, 1]; TODO confirm.)

    Returns:
        A tuple ``(c, c.mean())``. ``c`` is the broadcast elementwise selection
        and ``c.mean()`` is a 0-d tensor holding the mean over all its entries.
    """
    c = torch.where(a == 1, 1 - b, b)
    return c, c.mean()
# Column targets of shape (2, 1): each target applies to a whole row of prds.
trgts = tensor([[0], [1]])
prds = tensor([[0.4, 0.3], [0.1, 0.2]])
(trgts, prds)
ret_where(trgts, prds)

# Three rows of targets against a (3, 2) tensor. The selection is computed
# inline first, then through ret_where, so the two results can be compared.
trgts = tensor([[0], [1], [0]])
prds = tensor([[0.4, 0.3], [0.1, 0.2], [0.1, 0.6]])
(trgts, prds)
a = torch.where(trgts.eq(1), 1 - prds, prds)
(a, torch.mean(a))
ret_where(trgts, prds)