import {
    Color,
    Matrix4,
    Mesh,
    MeshPhongMaterial,
    Object3D,
    Scene,
    Vector3,
    SphereGeometry,
} from 'three';
import { TransformControls } from 'three/examples/jsm/controls/TransformControls.js';
import { URDFJoint, URDFLink, URDFMimicJoint, URDFRobot } from 'urdf-loader';
import {
    convertDegreeToRad,
    limit,
} from '../../shared/services/number-translate.service';
import { JointType } from 'src/app/shared/enums/joint-type.enum';

/**
 * Iterative inverse-kinematics helper for a chain of URDF joints.
 *
 * On construction the chain of revolute/prismatic joints from the scene down to
 * `endOfChain` is collected, a snapshot of every joint's world matrix is taken,
 * and an invisible target sphere is placed at the end joint. Each call to
 * `step()` performs one finite-difference gradient-descent pass that moves the
 * modelled end of the chain towards the sphere and writes the resulting joint
 * values back to the robot.
 *
 * NOTE(review): `forwardKinematics` models every joint value as a rotation about
 * the joint's local Z axis; the joint's own axis and the prismatic case are not
 * taken into account — confirm this matches the robots in use.
 */
export class InverseKinematicOwn {
    robot: URDFRobot;
    // Not assigned anywhere inside this class.
    controls: TransformControls;
    scene: Scene;
    // Joint whose world position the solver tries to bring onto the target.
    endOfChain: URDFJoint;

    // Movable joints ordered from the top of the hierarchy down to endOfChain.
    chain: URDFJoint[];
    // Snapshot of each chain joint's world matrix, taken in init().
    orignialChainWorldMatrix: Matrix4[];
    // Entry 0: world matrix of the first joint; entry i > 0: transform of joint i
    // relative to joint i - 1 (both taken from the snapshot).
    relativeOgChain: Matrix4[];
    // Joint values the solver currently works with, parallel to `chain`.
    currentAngles: number[];
    // The solver stops as soon as the modelled distance drops below this value.
    distanceThreshold = 0.01;
    // Step used for the finite-difference gradient (presumably 0.5° in radians).
    samplingDistance = convertDegreeToRad(0.5);
    // Scale applied to the gradient when a joint value is updated.
    learningrate: number = 0.125;
    // Invisible marker whose position is the IK target.
    targetSphere: Mesh;
    // True while realign() is moving the target; step() does nothing meanwhile.
    realigning = false;
    // Target position handled by the previous step(); used to skip redundant work.
    lastVector: Vector3;

    /**
     * @param endOfChain joint that should follow the target
     * @param robot robot the joint belongs to
     * @param scene scene the target sphere is added to; also the upper bound for
     *              the chain lookup
     */
    constructor(endOfChain: URDFJoint, robot: URDFRobot, scene: Scene) {
        this.endOfChain = endOfChain;
        this.robot = robot;
        this.scene = scene;

        this.init();
    }

    /**
     * Collects the joint chain, snapshots its world matrices and current joint
     * values, derives the relative transforms and creates the target sphere.
     */
    private init() {
        this.orignialChainWorldMatrix = [];
        this.relativeOgChain = [];
        this.currentAngles = [];
        this.chain = this.createKinematicChain(this.endOfChain);

        // Every chain joint is endOfChain itself or one of its ancestors, so a
        // single upward refresh makes all matrices read below current.
        this.endOfChain.updateWorldMatrix(true, false);

        for (const joint of this.chain) {
            // Clone: three.js updates matrixWorld in place, so a stored reference
            // would silently change whenever the robot moves and the base of the
            // forward-kinematics model would no longer be the original pose.
            this.orignialChainWorldMatrix.push(joint.matrixWorld.clone());
            this.currentAngles.push(joint.jointValue[0]?.valueOf() ?? 0);
        }

        this.computeRelativeChain();
        this.setupTargetControl(this.endOfChain);
    }

    /**
     * Creates the fully transparent target sphere at the world position of
     * `target` and adds it to the scene.
     */
    private setupTargetControl(target: URDFJoint) {
        const material = new MeshPhongMaterial({
            color: new Color(0x00ff00),
            opacity: 0,
            transparent: true,
        });

        this.targetSphere = new Mesh(new SphereGeometry(0.01), material);
        this.targetSphere.position.setFromMatrixPosition(target.matrixWorld);
        this.scene.add(this.targetSphere);
    }

    /**
     * Runs one solver pass if the target sphere moved since the previous call
     * and no realign is in progress, then applies the new values to the joints.
     *
     * NOTE(review): only one pass is executed per target change, so the chain
     * may stop short of the target when the sphere stops moving — confirm this
     * is intended.
     */
    step() {
        if (
            this.realigning ||
            this.equalsToLastVector(this.targetSphere.position)
        ) {
            return;
        }

        // Keep a copy, not the live position object, for the next comparison.
        this.lastVector = this.targetSphere.position.clone();

        // Solve on a copy so currentAngles is replaced only once the pass is done.
        const angles = [...this.currentAngles];
        this.inverseKinematics(this.targetSphere.position, angles);

        this.currentAngles = angles;
        this.setJointsFromList(angles, this.chain);
    }

    /**
     * Writes `angles[i]` to `joints[i]` for every joint that is not itself a
     * mimic joint and forwards the scaled/offset value to its mimic joints.
     * Logs an error and changes nothing if the two lists differ in length.
     */
    private setJointsFromList(angles: number[], joints: URDFJoint[]) {
        if (angles.length !== joints.length) {
            console.error(angles.length + ' !=' + joints.length);
            console.error('Joints and angles list should be of same length');
            return;
        }

        for (let i = 0; i < angles.length; i++) {
            const joint = joints[i];
            // Joints carrying a `mimicJoint` property are driven by another joint.
            if (joint['mimicJoint'] !== undefined) {
                continue;
            }
            joint.setJointValue(angles[i]);
            joint.mimicJoints.forEach((mimicJoint: URDFMimicJoint) => {
                mimicJoint.setJointValue(
                    angles[i] * mimicJoint.multiplier.valueOf() +
                        mimicJoint.offset.valueOf()
                );
            });
        }
    }

    /**
     * Moves the target sphere to the world position of the first child of
     * `endOfChain`.
     */
    public realign() {
        this.realigning = true;
        const linkChain = <URDFLink>this.endOfChain.children[0];
        this.targetSphere.position.copy(
            linkChain.getWorldPosition(new Vector3())
        );
        this.realigning = false;
    }

    /**
     * @returns true if a previous target exists and has exactly the same
     *          coordinates as `toCompare`
     */
    private equalsToLastVector(toCompare: Vector3): boolean {
        return (
            !!this.lastVector &&
            this.lastVector.x === toCompare?.x &&
            this.lastVector.y === toCompare?.y &&
            this.lastVector.z === toCompare?.z
        );
    }

    /**
     * One gradient-descent pass over the chain. `angles` is modified in place.
     * Each joint value is moved against its partial gradient and passed through
     * `limit` with the joint's lower/upper limit; the pass ends early once the
     * modelled distance is below `distanceThreshold`.
     */
    private inverseKinematics(target: Vector3, angles: number[]) {
        if (this.distanceFromTarget(target, angles) < this.distanceThreshold) {
            return;
        }
        for (let i = 0; i < this.chain.length; i++) {
            const gradient = this.partialGradient(target, angles, i);
            angles[i] -= this.learningrate * gradient;
            angles[i] = limit(
                angles[i],
                Number(this.chain[i].limit.lower),
                Number(this.chain[i].limit.upper)
            );
            if (
                this.distanceFromTarget(target, angles) < this.distanceThreshold
            ) {
                break;
            }
        }
    }

    /**
     * Forward finite difference of the target distance with respect to
     * `angles[index]`. `angles` is restored to its original content on return.
     */
    private partialGradient(
        target: Vector3,
        angles: number[],
        index: number
    ): number {
        const angle = angles[index];
        const f_x = this.distanceFromTarget(target, angles);
        angles[index] += this.samplingDistance;
        const f_xplusd = this.distanceFromTarget(target, angles);
        angles[index] = angle;
        return (f_xplusd - f_x) / this.samplingDistance;
    }

    /**
     * @returns distance between `target` and the chain end position that
     *          `forwardKinematics` yields for `angles`
     */
    distanceFromTarget(target: Vector3, angles: number[]): number {
        const point = this.forwardKinematics(
            angles,
            this.orignialChainWorldMatrix
        );
        return point.distanceTo(target);
    }

    /**
     * Fills `relativeOgChain` from the world-matrix snapshot: the first entry is
     * the first joint's world matrix, every later entry is the joint's transform
     * relative to its predecessor in the chain.
     */
    private computeRelativeChain() {
        for (let i = 0; i < this.orignialChainWorldMatrix.length; i++) {
            if (i === 0) {
                this.relativeOgChain.push(this.orignialChainWorldMatrix[i]);
            } else {
                this.relativeOgChain.push(
                    this.computeTransformationMatrix(
                        this.orignialChainWorldMatrix[i],
                        this.orignialChainWorldMatrix[i - 1]
                    )
                );
            }
        }
    }

    /**
     * Composes the chain transform for the given joint values and returns the
     * resulting position.
     *
     * Starting from `jointMatrix[0]`, each further joint i contributes a Z
     * rotation by `angles[i - 1]` followed by `relativeOgChain[i]`. Of
     * `jointMatrix` only the length and the first entry are read.
     */
    private forwardKinematics(
        angles: number[],
        jointMatrix: Matrix4[]
    ): Vector3 {
        const accumulated = new Matrix4();

        for (let i = 0; i < jointMatrix.length; i++) {
            if (i === 0) {
                accumulated.multiply(jointMatrix[0]);
                continue;
            }
            const rotation = this.buildTransformationMatrix(
                0,
                0,
                angles[i - 1],
                new Vector3(0, 0, 0)
            );
            accumulated.multiply(rotation);
            accumulated.multiply(this.relativeOgChain[i]);
        }

        return new Vector3().setFromMatrixPosition(accumulated);
    }

    /**
     * Builds a 4x4 transform whose rotation part is rotX(x) · rotY(y) · rotZ(z)
     * (angles in radians) and whose translation part is `translation`.
     * See https://en.wikipedia.org/wiki/Rotation_matrix
     */
    buildTransformationMatrix(
        x: number,
        y: number,
        z: number,
        translation: Vector3
    ): Matrix4 {
        const res = new Matrix4()
            .makeRotationX(x)
            .multiply(new Matrix4().makeRotationY(y))
            .multiply(new Matrix4().makeRotationZ(z));

        // Overwrites only the translation column; the rotation stays untouched.
        res.setPosition(translation);

        return res;
    }

    /**
     * @returns `last⁻¹ · cur`, i.e. `cur` expressed relative to `last`;
     *          neither argument is modified
     */
    private computeTransformationMatrix(cur: Matrix4, last: Matrix4): Matrix4 {
        const inv = last.clone().invert();
        return new Matrix4().multiplyMatrices(inv, cur);
    }

    /**
     * Collects the revolute and prismatic joints found when walking from
     * `startObj` up through its ancestors and returns them in reversed
     * (top-down) order.
     * @param startObj joint at the bottom end of the chain
     * @returns URDFJoint[] chain of revolute/prismatic joints, top first
     */
    private createKinematicChain(startObj: URDFJoint): URDFJoint[] {
        const chain: URDFJoint[] = bottomUpJointChain(startObj, this.scene);
        return chain.reverse();
    }
}

/**
 * Walks from `startObj` upwards through its ancestors and gathers every node of
 * type 'URDFJoint' whose joint type is revolute or prismatic.
 *
 * The walk ends when there is no parent left, or when a node that is not of
 * type 'URDFJoint' is a direct child of `top`.
 * @param startObj node to start from (included in the walk); a falsy value
 *                 yields an empty array
 * @param top object that bounds the upward walk
 * @returns URDFJoint[] matching joints, ordered from `startObj` upwards
 */
export function bottomUpJointChain(
    startObj: URDFJoint,
    top: Scene
): URDFJoint[] {
    const movableTypes: string[] = [
        JointType.REVOLUTE.toString(),
        JointType.PRISMATIC.toString(),
    ];
    const collected: URDFJoint[] = [];

    let node: Object3D | null = startObj;
    while (node) {
        if (node.type === 'URDFJoint') {
            const joint = node as URDFJoint;
            if (movableTypes.includes(joint.jointType)) {
                collected.push(joint);
            }
        } else if (node.parent === top) {
            // A non-joint node sitting directly below `top` ends the walk.
            break;
        }
        node = node.parent;
    }

    return collected;
}

/**
 * Gathers all revolute and prismatic joints below `currentRoot`.
 *
 * The tree is walked with `traverseVisible`; nodes of type 'URDFJoint' or
 * 'URDFMimicJoint' are kept when their joint type is revolute or prismatic.
 * @param currentRoot root to start from; a falsy value yields an empty array
 * @returns Array of URDFJoint in the order the traversal visits them
 */
export function getRevoluteAndPrismaticJoints(
    currentRoot: URDFRobot
): URDFJoint[] {
    const found: URDFJoint[] = [];
    if (!currentRoot) {
        return found;
    }

    const jointNodeTypes = ['URDFJoint', 'URDFMimicJoint'];
    const movableTypes: string[] = [
        JointType.REVOLUTE.toString(),
        JointType.PRISMATIC.toString(),
    ];

    currentRoot.traverseVisible((node: Object3D) => {
        if (!node || !jointNodeTypes.includes(node.type)) {
            return;
        }
        const joint = node as URDFJoint;
        if (movableTypes.includes(joint.jointType)) {
            found.push(joint);
        }
    });

    return found;
}
