import sys


def solve(m: int, n: int) -> int:
    """Return the gap between *m* and *n* scaled by the larger value.

    This is the closed form of the two-branch rule:
    ``(m - n) * m`` when ``m > n`` and ``(n - m) * n`` when ``m < n``.
    When ``m == n`` both branches evaluate to 0, so 0 is returned
    (the original script printed nothing in that case).

    >>> solve(5, 3)
    10
    >>> solve(4, 4)
    0
    """
    return abs(m - n) * max(m, n)


def main() -> None:
    """Read two whitespace-separated integers from stdin and print the answer.

    All of stdin is read, so the two numbers may be on one line or on
    separate lines; any tokens after the first two are ignored.
    """
    m, n = map(int, sys.stdin.read().split()[:2])
    print(solve(m, n))


if __name__ == "__main__":
    main()
# **************************************************************
#     Problem: 1324
#     User: admin
#     Language: Python
#     Result: Accepted
#     Time: 92 ms
#     Memory: 34244 kb
# **************************************************************