I had a lot of trouble understanding torch.where, so I tried to deconstruct how it works.

from fastai.vision.all import *
matplotlib.rc('image', cmap='Greys')
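Every example below builds its tensors with fastai's tensor. That function treats several list arguments as the rows of a 2-D tensor, as the printed outputs confirm. If you only have PyTorch installed, a minimal equivalent is sketched here. It assumes plain torch.tensor, which expects a single nested list instead.

```python
import torch

# fastai: tensor([0.8, 0.4, 0.2], [0.6, 0.2, 0.5]) makes each list argument a row.
# Plain PyTorch: wrap the rows in one outer list instead.
prds = torch.tensor([[0.8, 0.4, 0.2],
                     [0.6, 0.2, 0.5]])
print(prds.shape)  # torch.Size([2, 3])
```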

Element-wise operation

The targets 0, 1, 0 are broadcast across prds. They are compared element-wise with row 1 of prds, then with row 2, and so on.

trgts  = tensor([0, 1, 0])
prds   = tensor([0.8, 0.4, 0.2], [0.6, 0.2, 0.5])
trgts, prds
(tensor([0, 1, 0]),
 tensor([[0.8000, 0.4000, 0.2000],
         [0.6000, 0.2000, 0.5000]]))
torch.where(trgts == 1, 1-prds, prds).mean()
tensor(0.5833)
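To see the broadcast that torch.where performs implicitly, you can expand the mask by hand and check that the result is identical. This is a sketch in plain PyTorch. The names full_mask, implicit and explicit are only illustrative.

```python
import torch

trgts = torch.tensor([0, 1, 0])                   # shape (3,)
prds = torch.tensor([[0.8, 0.4, 0.2],
                     [0.6, 0.2, 0.5]])            # shape (2, 3)

mask = trgts == 1                                 # shape (3,): [False, True, False]
# Shapes are aligned from the right, so (3,) is treated as (1, 3)
# and the same mask row is reused for every row of prds.
full_mask = mask.expand_as(prds)                  # shape (2, 3)
print(full_mask)

implicit = torch.where(mask, 1 - prds, prds)       # torch.where broadcasts for us
explicit = torch.where(full_mask, 1 - prds, prds)  # broadcast done by hand
print(torch.equal(implicit, explicit))            # True
print(implicit)
```

Only the middle column is flipped to 1 - prds in both rows, which is why the mean comes out at 3.5 / 6.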

Single-element target

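A one-element target is broadcast to every element of prds, so the same comparison is applied to row 1, then row 2, and so on. With a target of 0 every prediction is kept as is. With a target of 1 every prediction becomes 1 - prds.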

trgts  = tensor([0])
prds   = tensor([0.4, 0.3],[0.1, 0.2])
trgts, prds
(tensor([0]),
 tensor([[0.4000, 0.3000],
         [0.1000, 0.2000]]))
torch.where(trgts == 1, 1-prds, prds).mean()
tensor(0.2500)
trgts  = tensor([1])
prds   = tensor([0.4, 0.3],[0.1, 0.2])
trgts, prds
(tensor([1]),
 tensor([[0.4000, 0.3000],
         [0.1000, 0.2000]]))
torch.where(trgts == 1, 1-prds, prds).mean()
tensor(0.7500)
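A dimension of size 1 stretches to match any size, so a shape (1,) target broadcasts to the full (2, 2) shape of prds. The whole tensor then takes the same branch. The sketch below checks that this matches a plain Python if/else on the target value. It assumes plain PyTorch, and torch.broadcast_shapes is used only to print the resulting shape.

```python
import torch

prds = torch.tensor([[0.4, 0.3],
                     [0.1, 0.2]])

means = {}
for t in (0, 1):
    trgts = torch.tensor([t])                                # shape (1,)
    print(torch.broadcast_shapes(trgts.shape, prds.shape))   # torch.Size([2, 2])
    out = torch.where(trgts == 1, 1 - prds, prds)
    # Every element takes the same branch, like an if/else on the scalar t.
    expected = 1 - prds if t == 1 else prds
    print(t, torch.equal(out, expected))
    means[t] = out.mean().item()

print(means)
```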

Create a helper function that returns the result and its mean

def ret_where(a, b):
    # Where the target a equals 1, flip the prediction to 1 - b; elsewhere keep b.
    c = torch.where(a == 1, 1-b, b)
    return c, c.mean()
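Not every pair of shapes can be broadcast. The helper below is hypothetical (ret_where_checked is not part of the original code). It is a variant of ret_where that first asks torch.broadcast_shapes for the combined shape. If the target and prediction shapes are incompatible, that call raises a RuntimeError.

```python
import torch

def ret_where_checked(a, b):
    """Hypothetical variant of ret_where that also reports the broadcast shape."""
    # torch.broadcast_shapes raises RuntimeError if a and b cannot be broadcast together.
    shape = torch.broadcast_shapes(a.shape, b.shape)
    c = torch.where(a == 1, 1 - b, b)
    return shape, c, c.mean()

prds = torch.tensor([[0.4, 0.3],
                     [0.1, 0.2]])

shape, c, m = ret_where_checked(torch.tensor([1]), prds)
print(shape)  # torch.Size([2, 2])

try:
    # (3,) vs (2, 2): the trailing sizes 3 and 2 differ and neither is 1.
    ret_where_checked(torch.tensor([0, 1, 0]), prds)
    failed = False
except RuntimeError as err:
    failed = True
    print("cannot broadcast:", err)
```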

A more interesting case: one target per row

trgts  = tensor([0], [1])
prds   = tensor([0.4, 0.3],[0.1, 0.2])
trgts, prds
(tensor([[0],
         [1]]),
 tensor([[0.4000, 0.3000],
         [0.1000, 0.2000]]))

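Here trgts has shape (2, 1), so each target is broadcast across the columns of its own row. The first target row applies to the first prediction row, and the second target row applies to the second prediction row.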

ret_where(trgts, prds)
(tensor([[0.4000, 0.3000],
         [0.9000, 0.8000]]),
 tensor(0.6000))
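The extra brackets matter. The sketch below compares a flat target of shape (2,) with a column target of shape (2, 1) against the same predictions, using plain PyTorch. The names flat_t, col_t, by_col and by_row are illustrative. Because broadcasting aligns shapes from the right, the flat target lines up with the columns, while the column target lines up with the rows.

```python
import torch

prds = torch.tensor([[0.4, 0.3],
                     [0.1, 0.2]])

flat_t = torch.tensor([0, 1])        # shape (2,): matched to the LAST dim (columns)
col_t = torch.tensor([[0], [1]])     # shape (2, 1): matched to the FIRST dim (rows)

by_col = torch.where(flat_t == 1, 1 - prds, prds)  # second column flipped
by_row = torch.where(col_t == 1, 1 - prds, prds)   # second row flipped
print(by_col)
print(by_row)
```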
trgts  = tensor([0], [1], [0])
prds   = tensor([0.4, 0.3], [0.1, 0.2], [0.1, 0.6])
trgts, prds
(tensor([[0],
         [1],
         [0]]),
 tensor([[0.4000, 0.3000],
         [0.1000, 0.2000],
         [0.1000, 0.6000]]))
a = torch.where(trgts == 1, 1-prds, prds)
a, a.mean()
(tensor([[0.4000, 0.3000],
         [0.9000, 0.8000],
         [0.1000, 0.6000]]),
 tensor(0.5167))

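Each target is broadcast across the entire corresponding row of predictions: the first target to the first row, the second target to the second row, and so on. The helper ret_where gives the same result: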

ret_where(trgts, prds)
(tensor([[0.4000, 0.3000],
         [0.9000, 0.8000],
         [0.1000, 0.6000]]),
 tensor(0.5167))
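In practice, targets often arrive as a flat vector with one entry per row. The sketch below assumes plain PyTorch and shows why such a vector has to become a column before torch.where can apply it row-wise. A flat (3,) mask is matched against the last dimension of a (3, 2) prediction tensor and fails to broadcast. Adding a trailing axis with unsqueeze(1) gives the (3, 1) shape used above.

```python
import torch

prds = torch.tensor([[0.4, 0.3],
                     [0.1, 0.2],
                     [0.1, 0.6]])        # shape (3, 2)
flat_t = torch.tensor([0, 1, 0])         # shape (3,): one target per row, stored flat

# (3,) is matched against the last dim of prds (size 2), so broadcasting fails.
try:
    torch.where(flat_t == 1, 1 - prds, prds)
    flat_ok = True
except RuntimeError:
    flat_ok = False

# A trailing axis turns it into (3, 1), which lines up with the rows instead.
col_t = flat_t.unsqueeze(1)              # shape (3, 1)
print((col_t == 1).expand_as(prds))      # each target copied across its row
out = torch.where(col_t == 1, 1 - prds, prds)
print(out, out.mean())
```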