/*
 * Decompiled with CFR 0.152.
 */
package no.uib.cipr.matrix.sparse;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import no.uib.cipr.matrix.AbstractMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.VectorEntry;
import no.uib.cipr.matrix.sparse.Arrays;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import no.uib.cipr.matrix.sparse.Preconditioner;
import no.uib.cipr.matrix.sparse.SparseVector;

/**
 * ILU preconditioner with fill-in controlled by a dual threshold (ILUT).
 * <p>
 * Row {@code i} is eliminated against all preceding rows. Multipliers whose
 * magnitude does not exceed {@code tau} times the 2-norm of row {@code i}
 * (taken before elimination) are dropped. When the row is written back, the
 * diagonal is always kept. Off-diagonal entries at or below the same threshold
 * are discarded. Each triangular part then keeps at most its original number
 * of entries plus {@code p}, chosen by largest magnitude.
 * <p>
 * The factorization is stored in place in the {@link FlexCompRowMatrix} given
 * to the constructor. {@link #apply} and {@link #transApply} write into a
 * shared scratch vector, so an instance must not be used by several threads at
 * once.
 */
public class ILUT implements Preconditioner {

    /** Matrix overwritten in place by the incomplete LU factors. */
    private final FlexCompRowMatrix LU;

    /** Unit lower triangular factor (a view onto {@link #LU}); null until factored. */
    private Matrix L;

    /** Upper triangular factor (a view onto {@link #LU}); null until factored. */
    private Matrix U;

    /** Scratch vector holding the intermediate result between the two triangular solves. */
    private final Vector y;

    /** Relative drop tolerance. */
    private final double tau;

    /** {@code diagind[k]} is the storage position of the diagonal entry in row {@code k}. */
    private final int[] diagind;

    /** Scratch list of candidate entries left of the diagonal, reused for every row. */
    private final List<IntDoubleEntry> lower;

    /** Scratch list of candidate entries right of the diagonal, reused for every row. */
    private final List<IntDoubleEntry> upper;

    /** Number of extra entries allowed in each triangular part of a row. */
    private final int p;

    /**
     * Sets up the ILUT preconditioner.
     *
     * @param LU  matrix used for the factorization; it is overwritten when
     *            {@link #setMatrix(Matrix)} is called
     * @param tau relative drop tolerance, scaled by the 2-norm of each row
     * @param p   number of entries kept in each of the lower and upper parts of a
     *            row in addition to that part's original entry count
     * @throws IllegalArgumentException if {@code LU} is not square
     */
    public ILUT(FlexCompRowMatrix LU, double tau, int p) {
        if (!LU.isSquare()) {
            throw new IllegalArgumentException("ILU only applies to square matrices");
        }
        this.LU = LU;
        this.tau = tau;
        this.p = p;
        int n = LU.numRows();
        this.lower = new ArrayList<IntDoubleEntry>(n);
        this.upper = new ArrayList<IntDoubleEntry>(n);
        this.y = new DenseVector(n);
        this.diagind = new int[n];
    }

    /**
     * Sets up the ILUT preconditioner with {@code tau = 1e-6} and {@code p = 25}.
     *
     * @param LU matrix used for the factorization; it is overwritten when
     *           {@link #setMatrix(Matrix)} is called
     */
    public ILUT(FlexCompRowMatrix LU) {
        this(LU, 1.0E-6, 25);
    }

    /**
     * Solves {@code (L U) x = b} with the factors from the last
     * {@link #setMatrix(Matrix)} call.
     *
     * @throws IllegalStateException if {@link #setMatrix(Matrix)} has not been called
     */
    @Override
    public Vector apply(Vector b, Vector x) {
        checkFactored();
        L.solve(b, y);
        return U.solve(y, x);
    }

    /**
     * Solves {@code (L U)^T x = b} with the factors from the last
     * {@link #setMatrix(Matrix)} call.
     *
     * @throws IllegalStateException if {@link #setMatrix(Matrix)} has not been called
     */
    @Override
    public Vector transApply(Vector b, Vector x) {
        checkFactored();
        U.transSolve(b, y);
        return L.transSolve(y, x);
    }

    /**
     * Copies {@code A} into the internal matrix, compacts its storage and
     * computes the incomplete factorization.
     */
    @Override
    public void setMatrix(Matrix A) {
        LU.set(A);
        LU.compact();
        factor();
    }

    /** Fails fast, with a clear message, when no factorization exists yet. */
    private void checkFactored() {
        if (L == null || U == null) {
            throw new IllegalStateException(
                    "setMatrix must be called before applying the ILUT preconditioner");
        }
    }

    /**
     * Computes the incomplete factorization in place in {@link #LU}, then builds
     * the triangular views {@link #L} and {@link #U} over it.
     *
     * @throws RuntimeException if a row has no diagonal entry, or a zero pivot
     *                          is met during elimination
     */
    private void factor() {
        int n = LU.numRows();
        double[] LUi = new double[n];

        // Locate the diagonal entry of every row
        for (int k = 0; k < n; ++k) {
            diagind[k] = findDiagonalIndex(LU.getRow(k), k);
            if (diagind[k] < 0) {
                throw new RuntimeException("Missing diagonal entry on row " + (k + 1));
            }
        }

        for (int i = 1; i < n; ++i) {
            SparseVector rowi = LU.getRow(i);

            // Drop tolerance for this row, relative to its norm before elimination
            double taui = rowi.norm(Vector.Norm.Two) * tau;

            // Expand row i into the dense work array
            scatter(rowi, LUi);

            for (int k = 0; k < i; ++k) {
                SparseVector rowk = LU.getRow(k);
                int[] rowIndex = rowk.getIndex();
                int rowUsed = rowk.getUsed();
                double[] rowData = rowk.getData();
                double pivot = rowData[diagind[k]];

                if (pivot == 0.0) {
                    throw new RuntimeException("Zero diagonal entry on row " + (k + 1)
                            + " during ILU process");
                }

                double LUik = LUi[k] / pivot;

                // Drop a small multiplier. It must be zeroed: otherwise the
                // unscaled value in LUi[k] could pass gather()'s threshold and
                // be stored in L as if it were the multiplier.
                if (Math.abs(LUik) <= taui) {
                    LUi[k] = 0.0;
                    continue;
                }

                // Subtract LUik times the strictly upper part of row k from row i
                for (int j = diagind[k] + 1; j < rowUsed; ++j) {
                    LUi[rowIndex[j]] -= LUik * rowData[j];
                }

                LUi[k] = LUik;
            }

            // Write row i back, dropping small entries and limiting fill-in
            gather(LUi, rowi, taui, i);

            // The row was rebuilt, so its cached diagonal position may be stale.
            // Only the first getUsed() indices belong to the row (the same bound
            // findDiagonalIndex searches), so validate against that, not the
            // length of the backing array.
            int diagIndex = diagind[i];
            if (diagIndex >= rowi.getUsed() || rowi.getIndex()[diagIndex] != i) {
                diagind[i] = findDiagonalIndex(rowi, i);
                if (diagind[i] < 0) {
                    throw new RuntimeException("Missing diagonal entry on row " + (i + 1)
                            + " during ILU process");
                }
            }
        }

        L = new UnitLowerFlexCompRowMatrix(LU, diagind);
        U = new UpperFlexCompRowMatrix(LU, diagind);
    }

    /**
     * Returns the storage position of column {@code k} among the used entries
     * of {@code v}, or a negative value if it is absent.
     */
    private int findDiagonalIndex(SparseVector v, int k) {
        return Arrays.binarySearch(v.getIndex(), k, 0, v.getUsed());
    }

    /** Clears {@code z} and copies the stored entries of {@code v} into it. */
    private void scatter(SparseVector v, double[] z) {
        int[] index = v.getIndex();
        int used = v.getUsed();
        double[] data = v.getData();
        java.util.Arrays.fill(z, 0.0);
        for (int i = 0; i < used; ++i) {
            z[index[i]] = data[i];
        }
    }

    /**
     * Writes the dense work row {@code z} back into the sparse row {@code v}
     * (row number {@code d}), applying the ILUT dropping rules.
     * <p>
     * The diagonal entry {@code z[d]} is always stored. Off-diagonal entries
     * with magnitude at most {@code taui} are discarded. Of the remaining
     * entries, the lower and the upper part each keep their largest-magnitude
     * entries, up to that part's entry count in {@code v} before this call plus
     * {@link #p}.
     */
    private void gather(double[] z, SparseVector v, double taui, int d) {
        // Count the current entries on each side of the diagonal
        int nl = 0;
        int nu = 0;
        for (VectorEntry entry : v) {
            if (entry.index() < d) {
                ++nl;
            } else if (entry.index() > d) {
                ++nu;
            }
        }

        v.zero();

        lower.clear();
        for (int i = 0; i < d; ++i) {
            if (Math.abs(z[i]) > taui) {
                lower.add(new IntDoubleEntry(i, z[i]));
            }
        }

        upper.clear();
        for (int i = d + 1; i < z.length; ++i) {
            if (Math.abs(z[i]) > taui) {
                upper.add(new IntDoubleEntry(i, z[i]));
            }
        }

        // Largest magnitudes first
        Collections.sort(lower);
        Collections.sort(upper);

        v.set(d, z[d]);

        int keepLower = Math.min(nl + p, lower.size());
        for (int i = 0; i < keepLower; ++i) {
            IntDoubleEntry e = lower.get(i);
            v.set(e.index, e.value);
        }

        int keepUpper = Math.min(nu + p, upper.size());
        for (int i = 0; i < keepUpper; ++i) {
            IntDoubleEntry e = upper.get(i);
            v.set(e.index, e.value);
        }
    }

    /**
     * A (column index, value) pair whose natural order is by decreasing
     * magnitude of the value.
     */
    private static class IntDoubleEntry implements Comparable<IntDoubleEntry> {
        public int index;
        public double value;

        public IntDoubleEntry(int index, double value) {
            this.index = index;
            this.value = value;
        }

        /** Orders entries by decreasing absolute value. */
        @Override
        public int compareTo(IntDoubleEntry o) {
            return Double.compare(Math.abs(o.value), Math.abs(value));
        }

        @Override
        public String toString() {
            return "(" + index + "=" + value + ")";
        }
    }

    /**
     * Unit lower triangular view of the in-place factors. It uses the entries
     * stored before the diagonal in each row, with an implicit unit diagonal.
     */
    private static class UnitLowerFlexCompRowMatrix extends AbstractMatrix {
        private final FlexCompRowMatrix LU;
        private final int[] diagind;

        public UnitLowerFlexCompRowMatrix(FlexCompRowMatrix LU, int[] diagind) {
            super(LU);
            this.LU = LU;
            this.diagind = diagind;
        }

        /**
         * Forward substitution for dense vectors. Other vector types are
         * delegated to {@link AbstractMatrix#solve}.
         */
        @Override
        public Vector solve(Vector b, Vector x) {
            if (!(b instanceof DenseVector) || !(x instanceof DenseVector)) {
                return super.solve(b, x);
            }
            double[] bd = ((DenseVector) b).getData();
            double[] xd = ((DenseVector) x).getData();
            for (int i = 0; i < numRows; ++i) {
                SparseVector row = LU.getRow(i);
                int[] index = row.getIndex();
                double[] data = row.getData();
                double sum = 0.0;
                for (int j = 0; j < diagind[i]; ++j) {
                    sum += data[j] * xd[index[j]];
                }
                xd[i] = bd[i] - sum;
            }
            return x;
        }

        /**
         * Solves with the transpose by a column-oriented backward sweep, for a
         * dense {@code x}. Other vector types are delegated to
         * {@link AbstractMatrix#transSolve}.
         */
        @Override
        public Vector transSolve(Vector b, Vector x) {
            if (!(x instanceof DenseVector)) {
                return super.transSolve(b, x);
            }
            x.set(b);
            double[] xd = ((DenseVector) x).getData();
            for (int i = numRows - 1; i >= 0; --i) {
                SparseVector row = LU.getRow(i);
                int[] index = row.getIndex();
                double[] data = row.getData();
                for (int j = 0; j < diagind[i]; ++j) {
                    xd[index[j]] -= data[j] * xd[i];
                }
            }
            return x;
        }
    }

    /**
     * Upper triangular view of the in-place factors. It uses the diagonal and
     * the entries stored after it in each row.
     */
    private static class UpperFlexCompRowMatrix extends AbstractMatrix {
        private final FlexCompRowMatrix LU;
        private final int[] diagind;

        public UpperFlexCompRowMatrix(FlexCompRowMatrix LU, int[] diagind) {
            super(LU);
            this.LU = LU;
            this.diagind = diagind;
        }

        /**
         * Backward substitution for dense vectors. Other vector types are
         * delegated to {@link AbstractMatrix#solve}.
         */
        @Override
        public Vector solve(Vector b, Vector x) {
            if (!(b instanceof DenseVector) || !(x instanceof DenseVector)) {
                return super.solve(b, x);
            }
            double[] bd = ((DenseVector) b).getData();
            double[] xd = ((DenseVector) x).getData();
            for (int i = numRows - 1; i >= 0; --i) {
                SparseVector row = LU.getRow(i);
                int[] index = row.getIndex();
                int used = row.getUsed();
                double[] data = row.getData();
                double sum = 0.0;
                for (int j = diagind[i] + 1; j < used; ++j) {
                    sum += data[j] * xd[index[j]];
                }
                xd[i] = (bd[i] - sum) / data[diagind[i]];
            }
            return x;
        }

        /**
         * Solves with the transpose by a column-oriented forward sweep, for a
         * dense {@code x}. Other vector types are delegated to
         * {@link AbstractMatrix#transSolve}.
         */
        @Override
        public Vector transSolve(Vector b, Vector x) {
            if (!(x instanceof DenseVector)) {
                return super.transSolve(b, x);
            }
            x.set(b);
            double[] xd = ((DenseVector) x).getData();
            for (int i = 0; i < numRows; ++i) {
                SparseVector row = LU.getRow(i);
                int[] index = row.getIndex();
                int used = row.getUsed();
                double[] data = row.getData();
                xd[i] /= data[diagind[i]];
                for (int j = diagind[i] + 1; j < used; ++j) {
                    xd[index[j]] -= data[j] * xd[i];
                }
            }
            return x;
        }
    }
}

