def count_other_factor_pairs(m, n):
    """Count divisors of ``m * n`` in ``[2, sqrt(m * n)]`` other than ``min(m, n)``.

    Each such divisor ``d`` pairs with ``(m * n) // d`` to give a
    factorization of the product whose smaller factor is at least 2 and is
    not ``min(m, n)``.
    NOTE(review): presumably this counts the other integer-sided rectangles
    with the same area as an ``m`` x ``n`` one -- confirm against the
    problem statement.

    Args:
        m: First integer.
        n: Second integer.

    Returns:
        The number of qualifying divisors. Products below 4 (including zero
        and negative products) have none, so the result is 0.
    """
    area = m * n
    if area < 4:
        # No d >= 2 can satisfy d * d <= area.
        return 0

    skipped = min(m, n)

    # Start from the float estimate, then correct it with exact integer
    # arithmetic: the float square root can be off by one for large products.
    limit = int(area ** 0.5)
    while limit * limit > area:
        limit -= 1
    while (limit + 1) * (limit + 1) <= area:
        limit += 1

    return sum(
        1
        for d in range(2, limit + 1)
        if area % d == 0 and d != skipped
    )


def main():
    """Read ``m`` and ``n`` from one line of stdin and print the count."""
    m, n = map(int, input().split())
    print(count_other_factor_pairs(m, n))


if __name__ == "__main__":
    main()

# **************************************************************
#   Problem: 1318
#   User: admin
#   Language: Python
#   Result: Accepted
#   Time: 1245 ms
#   Memory: 34480 kb
# **************************************************************