package de.renew.formalism.function;

import de.renew.expression.Function;
import de.renew.unify.Impossible;
import de.renew.unify.Tuple;
import de.renew.util.Value;

/**
 * A binary operator (logical, bitwise, relational, shift or arithmetic)
 * that is applied to the two components of a {@link Tuple}.
 * <p>
 * Every supported operator exists exactly once as a public constant of this
 * class; the constructor is private, so no further instances can be created.
 * Primitive operands are expected to be wrapped in {@link Value} objects and
 * primitive results are returned wrapped in a {@link Value}. The only result
 * that is not wrapped is the {@link String} produced by string concatenation
 * with {@link #PLUS}.
 */
public final class BasicFunction implements Function {
    // Index 0 is intentionally unused; the IDs double as indices into FUNC_NAMES.
    private static final int ID_LOR = 1;
    private static final int ID_LAND = 2;
    private static final int ID_OR = 3;
    private static final int ID_AND = 4;
    private static final int ID_XOR = 5;
    private static final int ID_EQUAL = 6;
    private static final int ID_NEQUAL = 7;
    private static final int ID_LESS = 8;
    private static final int ID_GREATER = 9;
    private static final int ID_LESSEQUAL = 10;
    private static final int ID_GREATEREQUAL = 11;
    private static final int ID_SHL = 12;
    private static final int ID_SHR = 13;
    private static final int ID_SSHR = 14;
    private static final int ID_PLUS = 15;
    private static final int ID_MINUS = 16;
    private static final int ID_TIMES = 17;
    private static final int ID_DIVIDE = 18;
    private static final int ID_MOD = 19;

    /**
     * The logical OR operation ({@code ||}) on two booleans
     * (true if at least one operand is true).
     * Both operands are already evaluated when the function is applied,
     * so there is no short-circuit effect.
     */
    public static final BasicFunction LOR = new BasicFunction(ID_LOR);

    /**
     * The logical AND operation ({@code &&}) on two booleans
     * (true only if both operands are true).
     * Both operands are already evaluated when the function is applied,
     * so there is no short-circuit effect.
     */
    public static final BasicFunction LAND = new BasicFunction(ID_LAND);

    /**
     * The OR operation ({@code |}).
     * For two booleans it yields true if at least one operand is true;
     * for integral numbers it combines the operands bit by bit.
     */
    public static final BasicFunction OR = new BasicFunction(ID_OR);

    /**
     * The AND operation ({@code &}).
     * For two booleans it yields true only if both operands are true;
     * for integral numbers it combines the operands bit by bit.
     */
    public static final BasicFunction AND = new BasicFunction(ID_AND);

    /**
     * The exclusive OR operation ({@code ^}).
     * For two booleans it yields true if exactly one operand is true;
     * for integral numbers it combines the operands bit by bit.
     */
    public static final BasicFunction XOR = new BasicFunction(ID_XOR);

    /**
     * The equality operation ({@code ==}).
     * Booleans and numbers are compared by value, operands that are not
     * wrapped in a {@link Value} are compared by reference.
     */
    public static final BasicFunction EQUAL = new BasicFunction(ID_EQUAL);

    /**
     * The inequality operation ({@code !=}).
     * Booleans and numbers are compared by value, operands that are not
     * wrapped in a {@link Value} are compared by reference.
     */
    public static final BasicFunction NEQUAL = new BasicFunction(ID_NEQUAL);

    /**
     * The less-than operation
     * (true if the first operand is less than the second).
     */
    public static final BasicFunction LESS = new BasicFunction(ID_LESS);

    /**
     * The greater-than operation
     * (true if the first operand is greater than the second).
     */
    public static final BasicFunction GREATER = new BasicFunction(ID_GREATER);

    /**
     * The less-than-or-equal operation
     * (true if the first operand is less than or equal to the second).
     */
    public static final BasicFunction LESSEQUAL = new BasicFunction(ID_LESSEQUAL);

    /**
     * The greater-than-or-equal operation
     * (true if the first operand is greater than or equal to the second).
     */
    public static final BasicFunction GREATEREQUAL = new BasicFunction(ID_GREATEREQUAL);

    /**
     * The shift-left operation ({@code <<}): shifts the bits of the first
     * operand to the left by the number of positions given by the second.
     */
    public static final BasicFunction SHL = new BasicFunction(ID_SHL);

    /**
     * The shift-right operation ({@code >>}): shifts the bits of the first
     * operand to the right by the number of positions given by the second,
     * filling with copies of the sign bit (the sign is preserved).
     */
    public static final BasicFunction SHR = new BasicFunction(ID_SHR);

    /**
     * The zero-filling shift-right operation ({@code >>>}): shifts the bits
     * of the first operand to the right by the number of positions given by
     * the second, filling with zero bits (the sign is <em>not</em> preserved).
     */
    public static final BasicFunction SSHR = new BasicFunction(ID_SSHR);

    /**
     * The addition operation (sum of two numbers). If at least one operand
     * is a {@link String}, the operands are concatenated instead.
     */
    public static final BasicFunction PLUS = new BasicFunction(ID_PLUS);

    /**
     * The subtraction operation
     * (difference between the two operands).
     */
    public static final BasicFunction MINUS = new BasicFunction(ID_MINUS);

    /**
     * The multiplication operation
     * (product of the two operands).
     */
    public static final BasicFunction TIMES = new BasicFunction(ID_TIMES);

    /**
     * The division operation
     * (quotient of the two operands).
     */
    public static final BasicFunction DIVIDE = new BasicFunction(ID_DIVIDE);

    /**
     * The modulus operation
     * (remainder when the first operand is divided by the second).
     */
    public static final BasicFunction MOD = new BasicFunction(ID_MOD);

    /** Display names of all functions, indexed by function ID. */
    private static final String[] FUNC_NAMES;
    private static final int FUNC_COUNT = 20;

    /** Name reported by {@link #toString()} for an ID without a display name. */
    private static final String INVALID_NAME = "BasicFunc(<<<INVALID!>>>)";

    /**
     * The ID (one of the {@code ID_*} constants) that selects the operation
     * performed by this function.
     */
    private final int _funcNum;

    static {
        FUNC_NAMES = new String[FUNC_COUNT];
        FUNC_NAMES[ID_LOR] = "BasicFunc(LOR)";
        FUNC_NAMES[ID_LAND] = "BasicFunc(LAND)";
        FUNC_NAMES[ID_OR] = "BasicFunc(OR)";
        FUNC_NAMES[ID_AND] = "BasicFunc(AND)";
        FUNC_NAMES[ID_XOR] = "BasicFunc(XOR)";
        FUNC_NAMES[ID_EQUAL] = "BasicFunc(EQUAL)";
        FUNC_NAMES[ID_NEQUAL] = "BasicFunc(NEQUAL)";
        FUNC_NAMES[ID_LESS] = "BasicFunc(LESS)";
        FUNC_NAMES[ID_GREATER] = "BasicFunc(GREATER)";
        FUNC_NAMES[ID_LESSEQUAL] = "BasicFunc(LESSEQUAL)";
        FUNC_NAMES[ID_GREATEREQUAL] = "BasicFunc(GREATEREQUAL)";
        FUNC_NAMES[ID_SHL] = "BasicFunc(SHL)";
        FUNC_NAMES[ID_SHR] = "BasicFunc(SHR)";
        FUNC_NAMES[ID_SSHR] = "BasicFunc(SSHR)";
        FUNC_NAMES[ID_PLUS] = "BasicFunc(PLUS)";
        FUNC_NAMES[ID_MINUS] = "BasicFunc(MINUS)";
        FUNC_NAMES[ID_TIMES] = "BasicFunc(TIMES)";
        FUNC_NAMES[ID_DIVIDE] = "BasicFunc(DIVIDE)";
        FUNC_NAMES[ID_MOD] = "BasicFunc(MOD)";
    }

    /**
     * Constructs a basic function with the given function ID.
     *
     * @param funcNum one of the {@code ID_*} constants
     */
    private BasicFunction(int funcNum) {
        this._funcNum = funcNum;
    }

    /**
     * Converts an operand to the string used for string concatenation.
     * A {@link Value} is represented by its wrapped payload, {@code null}
     * (also as the payload of a {@link Value}) by the string {@code "null"}.
     *
     * @param param the operand, possibly {@code null}
     * @return the string representation of the operand
     */
    private static String generalToString(Object param) {
        Object unwrapped = param instanceof Value wrapper ? wrapper.value : param;
        // String.valueOf maps null to "null" instead of failing.
        return String.valueOf(unwrapped);
    }

    /**
     * Applies this function to two operands that are not wrapped in a
     * {@link Value}. Only the equality operations are supported; they
     * compare the operands by reference, like Java's {@code ==} and
     * {@code !=} on objects.
     *
     * @param obj1 the first operand
     * @param obj2 the second operand
     * @return the wrapped boolean result of the comparison
     * @throws Impossible if this function is neither {@link #EQUAL}
     *                    nor {@link #NEQUAL}
     */
    private Value handleObjects(Object obj1, Object obj2) throws Impossible {
        boolean sameReference = obj1 == obj2;
        return switch (_funcNum) {
            case ID_EQUAL -> new Value(sameReference);
            case ID_NEQUAL -> new Value(!sameReference);
            default -> throw new Impossible();
        };
    }

    /**
     * Applies this function to two booleans.
     * <p>
     * The operands are taken as primitives on purpose: comparing boxed
     * {@link Boolean} objects with {@code ==} would compare references and
     * report equal truth values held by distinct objects as different.
     *
     * @param left the first boolean
     * @param right the second boolean
     * @return the wrapped boolean result
     * @throws Impossible if this function is not defined for booleans
     */
    private Value handleBooleans(boolean left, boolean right) throws Impossible {
        boolean result = switch (_funcNum) {
            case ID_LOR -> left || right;
            case ID_LAND -> left && right;
            case ID_OR -> left | right;
            case ID_AND -> left & right;
            case ID_XOR -> left ^ right;
            case ID_EQUAL -> left == right;
            case ID_NEQUAL -> left != right;
            default -> throw new Impossible();
        };
        return new Value(result);
    }

    /**
     * Applies this function to two doubles.
     *
     * @param obj1 the first double
     * @param obj2 the second double
     * @return a wrapped {@link Boolean} for relational operations,
     *         a wrapped {@link Double} for arithmetic operations
     * @throws Impossible if this function is not defined for
     *                    floating-point numbers
     */
    private Value handleDoubles(double obj1, double obj2) throws Impossible {
        // The target type Object makes every branch box to its own wrapper
        // type (Boolean or Double) instead of being promoted to a common one.
        Object result = switch (_funcNum) {
            case ID_EQUAL -> obj1 == obj2;
            case ID_NEQUAL -> obj1 != obj2;
            case ID_LESS -> obj1 < obj2;
            case ID_GREATER -> obj1 > obj2;
            case ID_LESSEQUAL -> obj1 <= obj2;
            case ID_GREATEREQUAL -> obj1 >= obj2;
            case ID_PLUS -> obj1 + obj2;
            case ID_MINUS -> obj1 - obj2;
            case ID_TIMES -> obj1 * obj2;
            case ID_DIVIDE -> obj1 / obj2;
            case ID_MOD -> obj1 % obj2;
            default -> throw new Impossible();
        };
        return new Value(result);
    }

    /**
     * Applies this function to two floats. The operation is carried out by
     * {@link #handleDoubles(double, double)}; a numeric result is narrowed
     * back to {@code float}.
     *
     * @param obj1 the first float
     * @param obj2 the second float
     * @return a wrapped {@link Boolean} for relational operations,
     *         a wrapped {@link Float} for arithmetic operations
     * @throws Impossible if this function is not defined for
     *                    floating-point numbers
     */
    private Value handleFloats(float obj1, float obj2) throws Impossible {
        Value result = handleDoubles(obj1, obj2);
        if (result.value instanceof Double wide) {
            return new Value(wide.floatValue());
        }
        // Relational operations yield a Boolean that is passed on unchanged.
        return result;
    }

    /**
     * Applies this function to two longs.
     * <p>
     * Division and modulus by zero are not intercepted; they fail with the
     * {@link ArithmeticException} raised by the Java operators.
     *
     * @param obj1 the first long
     * @param obj2 the second long
     * @return a wrapped {@link Boolean} for relational operations,
     *         a wrapped {@link Long} for bitwise and arithmetic operations
     * @throws Impossible if this function is not defined for integral numbers
     */
    private Value handleLongs(long obj1, long obj2) throws Impossible {
        // The target type Object makes every branch box to its own wrapper
        // type (Boolean or Long) instead of being promoted to a common one.
        Object result = switch (_funcNum) {
            case ID_OR -> obj1 | obj2;
            case ID_AND -> obj1 & obj2;
            case ID_XOR -> obj1 ^ obj2;
            case ID_EQUAL -> obj1 == obj2;
            case ID_NEQUAL -> obj1 != obj2;
            case ID_LESS -> obj1 < obj2;
            case ID_GREATER -> obj1 > obj2;
            case ID_LESSEQUAL -> obj1 <= obj2;
            case ID_GREATEREQUAL -> obj1 >= obj2;
            case ID_PLUS -> obj1 + obj2;
            case ID_MINUS -> obj1 - obj2;
            case ID_TIMES -> obj1 * obj2;
            case ID_DIVIDE -> obj1 / obj2;
            case ID_MOD -> obj1 % obj2;
            default -> throw new Impossible();
        };
        return new Value(result);
    }

    /**
     * Applies this function to two integers. The operation is carried out by
     * {@link #handleLongs(long, long)}; a numeric result is narrowed back to
     * {@code int}, which keeps the low 32 bits of the long result.
     *
     * @param obj1 the first integer
     * @param obj2 the second integer
     * @return a wrapped {@link Boolean} for relational operations,
     *         a wrapped {@link Integer} for bitwise and arithmetic operations
     * @throws Impossible if this function is not defined for integral numbers
     */
    private Value handleInts(int obj1, int obj2) throws Impossible {
        Value result = handleLongs(obj1, obj2);
        if (result.value instanceof Long wide) {
            return new Value(wide.intValue());
        }
        // Relational operations yield a Boolean that is passed on unchanged.
        return result;
    }

    /**
     * Applies a shift operation to two integral numbers. The result type
     * depends only on the left operand: a {@link Long} is shifted as
     * {@code long}, every other number is shifted as {@code int}.
     *
     * @param obj1 the number to shift; must be a {@link Number}
     * @param obj2 the shift distance; must be a {@link Number}
     * @return a wrapped {@link Long} if the left operand is a {@link Long},
     *         a wrapped {@link Integer} otherwise
     * @throws Impossible if this function is not a shift operation
     */
    private Value handleShifts(Object obj1, Object obj2) throws Impossible {
        long distance = ((Number) obj2).longValue();

        // The two branches look alike, but the static type of the left
        // operand (long vs. int) determines the width of the shift and the
        // wrapper type of the result, so the switch has to be duplicated.
        if (obj1 instanceof Long) {
            long shifted = ((Number) obj1).longValue();
            Object result = switch (_funcNum) {
                case ID_SHL -> shifted << distance;
                case ID_SHR -> shifted >> distance;
                case ID_SSHR -> shifted >>> distance;
                default -> throw new Impossible();
            };
            return new Value(result);
        }

        int shifted = ((Number) obj1).intValue();
        Object result = switch (_funcNum) {
            case ID_SHL -> shifted << distance;
            case ID_SHR -> shifted >> distance;
            case ID_SSHR -> shifted >>> distance;
            default -> throw new Impossible();
        };
        return new Value(result);
    }

    /**
     * Applies this function to the two components of a tuple.
     * <p>
     * The operands are dispatched as follows:
     * <ol>
     * <li>{@link #PLUS} with at least one {@link String} operand
     *     concatenates the string representations of both operands and
     *     returns a plain {@link String}.</li>
     * <li>If exactly one operand is a {@link Value}, the function is not
     *     applicable.</li>
     * <li>Two operands that are not {@link Value}s support only
     *     {@link #EQUAL} and {@link #NEQUAL}, compared by reference.</li>
     * <li>Two wrapped booleans support the logical, bitwise and equality
     *     operations; a boolean cannot be combined with a non-boolean.</li>
     * <li>Wrapped characters are treated as {@code int}. Numbers are then
     *     processed as {@code double}, {@code float}, {@code long} or
     *     {@code int}, whichever is the widest type among the operands.
     *     Shift operations on integral numbers take their result type from
     *     the left operand only.</li>
     * </ol>
     *
     * @param param a {@link Tuple} with exactly two components
     * @return the result, wrapped in a {@link Value} unless it is the
     *         {@link String} produced by concatenation
     * @throws Impossible if the tuple does not have exactly two components
     *                    or if this function is not defined for the given
     *                    operands
     */
    @Override
    public Object function(Object param) throws Impossible {
        Tuple tuple = (Tuple) param;
        if (tuple.getArity() != 2) {
            throw new Impossible();
        }
        Object obj1 = tuple.getComponent(0);
        Object obj2 = tuple.getComponent(1);

        // String concatenation works on the raw operands; there is no need
        // to unwrap values first. Any other operation on strings falls
        // through to the general rules below.
        if (_funcNum == ID_PLUS && (obj1 instanceof String || obj2 instanceof String)) {
            return generalToString(obj1) + generalToString(obj2);
        }

        // Either both operands are values or none of them is.
        if (obj1 instanceof Value ^ obj2 instanceof Value) {
            throw new Impossible();
        }

        // Plain objects support the equality operations only.
        if (!(obj1 instanceof Value)) {
            return handleObjects(obj1, obj2);
        }

        // Both arguments are values.
        obj1 = ((Value) obj1).value;
        obj2 = ((Value) obj2).value;

        // Booleans are next. Passing them on as primitives guarantees a
        // comparison by truth value instead of by object identity.
        if (obj1 instanceof Boolean || obj2 instanceof Boolean) {
            if (obj1 instanceof Boolean first && obj2 instanceof Boolean second) {
                return handleBooleans(first, second);
            }
            throw new Impossible();
        }

        // Characters are replaced by integers, so that only numbers are
        // left: byte, short, integer, long, float, double.
        if (obj1 instanceof Character character) {
            obj1 = (int) character.charValue();
        }
        if (obj2 instanceof Character character) {
            obj2 = (int) character.charValue();
        }

        // Everything that is still not a number (including a null payload)
        // cannot be processed. Report this as Impossible, like every other
        // unsupported operand combination, instead of failing in one of the
        // casts below.
        if (!(obj1 instanceof Number) || !(obj2 instanceof Number)) {
            throw new Impossible();
        }
        Number num1 = (Number) obj1;
        Number num2 = (Number) obj2;

        // Now we proceed with the unwrapped numbers, widest type first.
        if (num1 instanceof Double || num2 instanceof Double) {
            return handleDoubles(num1.doubleValue(), num2.doubleValue());
        }

        if (num1 instanceof Float || num2 instanceof Float) {
            return handleFloats(num1.floatValue(), num2.floatValue());
        }

        // Shift operators must be handled separately, because their result
        // type depends only on the type of the left argument.
        if (_funcNum == ID_SHL || _funcNum == ID_SHR || _funcNum == ID_SSHR) {
            return handleShifts(num1, num2);
        }

        if (num1 instanceof Long || num2 instanceof Long) {
            return handleLongs(num1.longValue(), num2.longValue());
        }

        return handleInts(num1.intValue(), num2.intValue());
    }

    /**
     * Returns the display name of this function, for example
     * {@code "BasicFunc(PLUS)"}.
     *
     * @return the display name, or {@code "BasicFunc(<<<INVALID!>>>)"} if
     *         the function ID has no display name
     */
    @Override
    public String toString() {
        boolean inRange = _funcNum >= 0 && _funcNum < FUNC_NAMES.length;
        String name = inRange ? FUNC_NAMES[_funcNum] : null;
        return name != null ? name : INVALID_NAME;
    }
}