python pytorch tensor

Count how many elements of one Tensor exist in another Tensor


I have two 1D tensors:

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

The tensors are extremely large, differ in length, and their values are neither sequential nor sorted.

I want to count the elements of B that (i) exist in A and (ii) do not exist in A. For the example above, the output should be:

Exists: 4
Do not exist: 3

I have tried:

exists = torch.eq(A,B).sum().item()
not_exist = torch.numel(B) - exists

But this gives the error:

RuntimeError: The size of tensor a (10) must match the size of tensor b (7) at non-singleton dimension 0

The following approach works, but it first creates a boolean array and then sums its True entries. Is it efficient for very large tensors?

exists = np.isin(B, A).sum()
not_exist = torch.numel(B) - exists
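PyTorch 1.10 and later also provide `torch.isin`, so the same membership test can stay inside PyTorch (and on the GPU, if the tensors live there) instead of going through NumPy. A minimal sketch, assuming a PyTorch version that includes `torch.isin`; note the argument order: the first argument is the tensor whose elements are tested (B), the second is the tensor they are looked up in (A):

```python
import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

# Boolean mask with the same shape as B: True where that element of B occurs in A
mask = torch.isin(B, A)

exists = mask.sum().item()
not_exist = B.numel() - exists

print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
```

Like the NumPy version, this still materializes one boolean value per element of B, but that mask is only len(B) long.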

Is there any better or more efficient approach?


Solution

  • Try the following:

    import torch

    A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    B = torch.tensor([2, 5, 6, 8, 12, 15, 16])
    
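    # Build Python sets of the values (A.numpy() requires CPU tensors)
    setA = set(A.numpy())
    setB = set(B.numpy())
    
    intersection = setA & setB  # distinct values of B that also occur in A
    difference = setB - setA    # distinct values of B that do not occur in A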
    
    exists = len(intersection)
    not_exist = len(difference)
    
    print(f"Exists: {exists}")
    print(f"Do not exist: {not_exist}")
    
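Because the code above converts B to a set, it counts distinct values: a value that appears several times in B is counted only once, so the two counts need not add up to `torch.numel(B)`. If every occurrence in B should be counted, a minimal variant builds a set from A only and looks up each element of B. The repeated 5 in this B is a hypothetical input chosen to show the difference, and `.cpu().tolist()` is used so the sketch also accepts GPU tensors:

```python
import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# Hypothetical input: 5 appears twice to show that repeats are counted
B = torch.tensor([2, 5, 5, 6, 8, 12, 15, 16])

# Plain Python ints; .cpu() makes this work for tensors on any device
setA = set(A.cpu().tolist())

# Test every element of B, including repeats, against the set built from A
exists = sum(1 for x in B.cpu().tolist() if x in setA)
not_exist = B.numel() - exists

print(f"Exists: {exists}")           # Exists: 5
print(f"Do not exist: {not_exist}")  # Do not exist: 3
```

The set-based version from the answer would report 4 and 3 for this B, since the duplicate 5 collapses into a single set entry.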

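    Update:

    You can also stick with native PyTorch operations by using broadcasting. Comparing B as a column against A as a row yields a len(B) x len(A) boolean matrix in one vectorized operation, which can run on the GPU and avoids converting the tensors to Python sets. The cost is memory: that matrix grows with len(B) * len(A), so for extremely large tensors it may not fit and may need to be processed in chunks.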

    import torch
    
    A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
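    B = torch.tensor([2, 5, 6, 8, 12, 15, 16])
    
    # Broadcast B as a column (len(B), 1) against A as a row (1, len(A)):
    # entry [i, j] is True when B[i] == A[j]
    comparison_matrix = B.unsqueeze(1) == A.unsqueeze(0)
    
    # Number of matches in A for each element of B
    matches_per_element = comparison_matrix.sum(dim=1)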
    
    exists = (matches_per_element > 0).sum().item()
    not_exist = len(B) - exists
    
    print(f"Exists: {exists}")
    print(f"Do not exist: {not_exist}")
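If the len(B) x len(A) matrix is too large to materialize, a different technique (not the one used in the Update) is to sort A once and binary-search every element of B in it with `torch.sort` and `torch.searchsorted`. A minimal sketch, assuming A and B are 1-D, on the same device and of the same dtype; `count_membership` is a hypothetical helper name:

```python
import torch
from typing import Tuple


def count_membership(A: torch.Tensor, B: torch.Tensor) -> Tuple[int, int]:
    """Return (exists, not_exist): how many elements of B, counting repeats,
    do and do not occur in A.

    Sorts A once and binary-searches each element of B in it, so memory
    grows with len(A) + len(B) rather than len(A) * len(B).
    """
    if A.numel() == 0:
        return 0, B.numel()

    sorted_A, _ = torch.sort(A)
    # Position where each element of B would be inserted into sorted_A
    idx = torch.searchsorted(sorted_A, B)
    # Values larger than max(A) get idx == len(A); clamp so indexing stays valid
    idx = idx.clamp(max=sorted_A.numel() - 1)
    # An element of B is present iff the value at its insertion point equals it
    found = sorted_A[idx] == B
    exists = int(found.sum().item())
    return exists, B.numel() - exists


A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

exists, not_exist = count_membership(A, B)
print(f"Exists: {exists}")           # Exists: 4
print(f"Do not exist: {not_exist}")  # Do not exist: 3
```

This trades the quadratic comparison matrix for one sort of A plus a logarithmic lookup per element of B. Like the broadcasting version, it counts repeated elements of B individually.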