/*
 * Decompiled with CFR 0.152.
 */
package teamroots.embers.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.function.DoubleFunction;
import net.minecraft.util.math.MathHelper;
import net.minecraft.util.math.Vec3d;

/**
 * Reparameterizes a parametric curve by arc length.
 *
 * <p>{@link #cachePoints(int, double, double)} samples {@link #function} over the parameter range
 * [0, 1], records the cumulative straight-line distance between consecutive samples, and buckets
 * the samples by that distance. {@link #getIndex(double)} then maps an arc length back to a curve
 * parameter by starting at the sample stored for the arc length's bucket and walking to the two
 * samples that bracket it. {@code cachePoints} must be called before any lookup.
 */
public class Spline {
    /** The curve being sampled; called with parameters in [0, 1]. */
    DoubleFunction<Vec3d> function;
    /** Cached samples, ordered by increasing parameter and non-decreasing cumulative distance. */
    ArrayList<Point> points = new ArrayList<>();
    /**
     * Bucket key (a distance rounded down to a multiple of {@link #pointInterval}) to an index in
     * {@link #points}: the last sample falling in that bucket, or, for a bucket with no samples,
     * the entry of the preceding bucket.
     */
    HashMap<Double, Integer> pointMap = new HashMap<>();
    /** Bucket width, in distance units, of {@link #pointMap}. */
    double pointInterval;
    /** Cumulative distance of the last cached sample. */
    double totalArcLength;

    public Spline(DoubleFunction<Vec3d> function) {
        this.function = function;
    }

    /**
     * Returns the total length measured by the last call to {@link #cachePoints}.
     *
     * @return the cumulative distance through all cached samples (0 before caching)
     */
    public double getTotalArcLength() {
        return this.totalArcLength;
    }

    /**
     * Evaluates the curve at the given distance along it.
     *
     * @param arcLength distance from the start of the curve; clamped to [0, total arc length]
     * @return the curve point at the parameter returned by {@link #getIndex(double)}
     * @throws IllegalStateException if {@link #getIndex(double)} finds no cached sample
     */
    public Vec3d getPoint(double arcLength) {
        return this.function.apply(this.getIndex(arcLength));
    }

    /**
     * Maps an arc length to the curve parameter at that distance along the cached samples.
     *
     * @param arcLength distance from the start of the curve; clamped to [0, total arc length]
     * @return the parameter of the sample lying exactly at {@code arcLength}, or a value linearly
     *     interpolated between the parameters of the two samples that bracket it
     * @throws IllegalStateException if no cached sample covers the clamped arc length, e.g. because
     *     {@link #cachePoints} has not been called yet
     */
    public double getIndex(double arcLength) {
        double clamped = MathHelper.clamp(arcLength, 0.0, this.totalArcLength);
        Integer bucketPoint = this.pointMap.get(this.calculateIndex(clamped));
        if (bucketPoint == null) {
            throw new IllegalStateException(
                    "No cached point for arc length " + arcLength + "; was cachePoints called?");
        }
        int closeIndex = bucketPoint;
        double dist = this.points.get(closeIndex).distance;
        if (clamped == dist) {
            return this.points.get(closeIndex).index;
        }
        if (clamped < dist) {
            return this.getIndexLeft(closeIndex, clamped);
        }
        return this.getIndexRight(closeIndex, clamped);
    }

    /**
     * Walks toward the start of the curve from {@code rightIndex} until a sample at or before
     * {@code arcLength} is found. Iterative so that dense caches cannot overflow the stack.
     */
    private double getIndexLeft(int rightIndex, double arcLength) {
        int right = rightIndex;
        while (true) {
            int left = right - 1;
            Point leftPoint = this.points.get(left);
            if (arcLength == leftPoint.distance) {
                return leftPoint.index;
            }
            if (arcLength < leftPoint.distance) {
                // Still before this sample: step one further left.
                right = left;
                continue;
            }
            return this.interpolateIndex(left, right, arcLength);
        }
    }

    /**
     * Walks toward the end of the curve from {@code leftIndex} until a sample at or after
     * {@code arcLength} is found. Iterative so that dense caches cannot overflow the stack.
     */
    private double getIndexRight(int leftIndex, double arcLength) {
        int left = leftIndex;
        while (true) {
            int right = left + 1;
            Point rightPoint = this.points.get(right);
            if (arcLength == rightPoint.distance) {
                return rightPoint.index;
            }
            if (arcLength < rightPoint.distance) {
                return this.interpolateIndex(left, right, arcLength);
            }
            // Still past this sample: step one further right.
            left = right;
        }
    }

    /**
     * Linearly interpolates the curve parameter between two adjacent samples according to where
     * {@code arcLength} falls between their cumulative distances.
     */
    private double interpolateIndex(int leftIndex, int rightIndex, double arcLength) {
        Point leftPoint = this.points.get(leftIndex);
        Point rightPoint = this.points.get(rightIndex);
        double fraction =
                (arcLength - leftPoint.distance) / (rightPoint.distance - leftPoint.distance);
        return MathHelper.clampedLerp(leftPoint.index, rightPoint.index, fraction);
    }

    /**
     * Returns the bucket key for a distance: {@code arcLength} rounded down to a multiple of the
     * cache interval.
     */
    public double calculateIndex(double arcLength) {
        return Math.floor(arcLength / this.pointInterval) * this.pointInterval;
    }

    /**
     * (Re)builds the sample cache and the arc-length lookup table, discarding any previous cache.
     *
     * @param minSegments number of evenly spaced initial samples over [0, 1], both ends included;
     *     must be at least 2
     * @param maxDist adjacent samples farther apart than this are subdivided at the parameter
     *     midpoint, until they are close enough or the midpoint is no longer a distinct double
     * @param cacheInterval bucket width, in distance units, of the lookup table; must be positive
     *     and finite
     * @throws IllegalArgumentException if {@code minSegments < 2} or {@code cacheInterval} is not a
     *     positive finite number; the existing cache is left untouched in that case
     */
    public void cachePoints(int minSegments, double maxDist, double cacheInterval) {
        if (minSegments < 2) {
            // minSegments == 1 would sample at 0/0 (NaN); fewer leaves nothing to look up.
            throw new IllegalArgumentException("minSegments must be at least 2, got " + minSegments);
        }
        if (!(cacheInterval > 0.0 && Double.isFinite(cacheInterval))) {
            // A zero/negative/NaN/infinite width makes bucket keys NaN or the fill loop endless.
            throw new IllegalArgumentException(
                    "cacheInterval must be positive and finite, got " + cacheInterval);
        }
        // Start from scratch so repeated calls do not append to a stale cache.
        this.points.clear();
        this.pointMap.clear();
        this.pointInterval = cacheInterval;
        this.sampleUniformly(minSegments);
        this.subdivide(maxDist);
        this.measureDistances();
        this.fillEmptyBuckets();
    }

    /** Adds {@code count} samples at evenly spaced parameters from 0 to 1 inclusive. */
    private void sampleUniformly(int count) {
        for (int i = 0; i < count; ++i) {
            double index = (double) i / (double) (count - 1);
            this.points.add(new Point(index, this.function.apply(index)));
        }
    }

    /**
     * Inserts parameter-midpoint samples, depth-first and left-first, until every pair of adjacent
     * samples is at most {@code maxDist} apart or cannot be split further.
     */
    private void subdivide(double maxDist) {
        double maxDistSq = maxDist * maxDist;
        int i = 0;
        while (i < this.points.size() - 1) {
            Point start = this.points.get(i);
            Point end = this.points.get(i + 1);
            double midpoint = (start.index + end.index) / 2.0;
            // The midpoint checks stop subdivision once the doubles can no longer be split.
            if (midpoint != start.index
                    && midpoint != end.index
                    && start.point.squareDistanceTo(end.point) > maxDistSq) {
                // Insert and re-examine the (start, midpoint) pair on the next pass.
                this.points.add(i + 1, new Point(midpoint, this.function.apply(midpoint)));
                continue;
            }
            ++i;
        }
    }

    /**
     * Stores the cumulative distance on each sample, records each sample under its bucket (later
     * samples in the same bucket overwrite earlier ones), and sets {@link #totalArcLength}.
     */
    private void measureDistances() {
        double totalDistance = 0.0;
        for (int i = 0; i < this.points.size(); ++i) {
            Point current = this.points.get(i);
            if (i < this.points.size() - 1) {
                Point next = this.points.get(i + 1);
                totalDistance += current.point.distanceTo(next.point);
                next.distance = totalDistance;
            }
            this.pointMap.put(this.calculateIndex(current.distance), i);
        }
        this.totalArcLength = totalDistance;
    }

    /**
     * Gives every bucket below the total arc length an entry. Buckets that received no sample
     * inherit the preceding bucket's sample, so a lookup always starts at or before its target.
     */
    private void fillEmptyBuckets() {
        for (int i = 0; (double) i * this.pointInterval < this.totalArcLength; ++i) {
            double bucket = (double) i * this.pointInterval;
            double previousBucket = (double) (i - 1) * this.pointInterval;
            // Bucket 0 always holds the first sample (distance 0), so the chain is never null.
            this.pointMap.putIfAbsent(bucket, this.pointMap.get(previousBucket));
        }
    }

    /** One cached sample of the curve. */
    static class Point {
        /** Curve parameter at which the sample was taken. */
        public double index;
        /** Curve position at {@link #index}. */
        public Vec3d point;
        /** Cumulative straight-line distance from the first sample to this one. */
        public double distance;

        public Point(double index, Vec3d point) {
            this.index = index;
            this.point = point;
        }
    }
}

