/*
 * Decompiled with CFR 0.152.
 */
package me.modmuss50.optifabric.patcher;

import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Runnables;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import me.modmuss50.optifabric.patcher.Lambda;
import me.modmuss50.optifabric.patcher.MethodComparison;
import me.modmuss50.optifabric.shadow.tinyremapper.IMappingProvider;
import me.modmuss50.optifabric.util.ASMUtils;
import net.fabricmc.loader.api.FabricLoader;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;

/**
 * Pairs the synthetic {@code lambda$} methods of OptiFine patched classes with the vanilla methods they replace.
 * This lets the patched lambdas be renamed back to their vanilla names.
 *
 * <p>Pairings whose descriptors are identical are recorded in {@code fixes}. Pairings whose descriptors differ
 * are recorded in {@code fuzzes}, but only when vague equivalence is allowed. Both are handed to the remapper
 * through {@link #load(IMappingProvider.MappingAcceptor)}.
 */
public class LambdaRebuilder implements Closeable, IMappingProvider {
    /** Whether lambdas whose descriptor changed may still be paired; disabled by {@code -Doptifabric.exactOnly=true}. */
    private static final boolean ALLOW_VAGUE_EQUIVALENCE = !Boolean.getBoolean("optifabric.exactOnly");
    /** Vanilla client jar that original classes are read from; {@code null} for the command line instance. */
    private final JarFile minecraftClientFile;
    /** OptiFine lambda => remapped vanilla name, for pairings whose descriptors match exactly. */
    private final Map<IMappingProvider.Member, String> fixes = new HashMap<>();
    /** OptiFine lambda => (vanilla name, vanilla descriptor), for pairings whose descriptors differ. */
    protected final Map<IMappingProvider.Member, Pair<String, String>> fuzzes = ALLOW_VAGUE_EQUIVALENCE ? new HashMap<>() : Collections.emptyMap();

    /**
     * Command line entry point which compares one vanilla class file against its OptiFine counterpart.
     * It prints every lambda pairing that was found.
     *
     * @param args the vanilla class file followed by the OptiFine class file
     * @throws IOException if either class file cannot be read
     */
    public static void main(String ... args) throws IOException {
        if (args == null || args.length != 2) {
            System.out.println("Usage: <vanilla_class> <optifine_class>");
            return;
        }
        File vanilla = new File(args[0]);
        if (!vanilla.exists() || !vanilla.isFile()) {
            System.err.println("Invalid vanilla class: " + args[0]);
            System.exit(1);
        }
        File optifine = new File(args[1]);
        if (!optifine.exists() || !optifine.isFile()) {
            System.err.println("Invalid OptiFine class: " + args[1]);
            System.exit(1);
        }
        ClassNode minecraft = ASMUtils.readClass(vanilla);
        ClassNode patched = ASMUtils.readClass(optifine);

        //Outside of a Fabric environment there is no mapping resolver, so names are reported as they are
        try (LambdaRebuilder rebuilder = new LambdaRebuilder() {
            @Override
            protected final String remapName(String owner, String name, String desc) {
                return name;
            }
        }) {
            int unsolved = rebuilder.findLambdas(minecraft.name, minecraft.methods, patched.methods);
            int total = rebuilder.fixes.size() + rebuilder.fuzzes.size();
            System.out.printf(unsolved == 0 ? "Fully matched up %d lambdas:%n" : "Partially matched %d/%d lambdas%n", total, total + unsolved);

            for (Map.Entry<IMappingProvider.Member, String> entry : rebuilder.fixes.entrySet()) {
                IMappingProvider.Member lambda = entry.getKey();
                System.out.printf("\t%s#%s%s => %s%s%n", lambda.owner, lambda.name, lambda.desc, entry.getValue(), lambda.desc);
            }
            for (Map.Entry<IMappingProvider.Member, Pair<String, String>> entry : rebuilder.fuzzes.entrySet()) {
                IMappingProvider.Member lambda = entry.getKey();
                Pair<String, String> remap = entry.getValue();
                System.out.printf("\t%s#%s%s => %s%s%n", lambda.owner, lambda.name, lambda.desc, remap.getLeft(), remap.getRight());
            }
        }
    }

    /** Creates an instance with no vanilla jar; used by {@link #main(String...)}, which never calls {@link #findLambdas(ClassNode)}. */
    private LambdaRebuilder() {
        this.minecraftClientFile = null;
    }

    /**
     * @param minecraftClientFile the vanilla client jar to read original classes from
     * @throws IOException if the jar cannot be opened
     */
    public LambdaRebuilder(File minecraftClientFile) throws IOException {
        this.minecraftClientFile = new JarFile(minecraftClientFile);
    }

    /**
     * Pairs the lambdas of the given OptiFine class with those of the vanilla class of the same name.
     * For fuzzy pairings where a private lambda's static-ness changed, it then rewrites the patched class.
     *
     * @param patched the OptiFine version of the class
     * @throws IOException if the vanilla class cannot be read
     * @throws IllegalArgumentException if the class is not present in the vanilla jar
     */
    public void findLambdas(ClassNode patched) throws IOException {
        JarEntry entry = this.minecraftClientFile.getJarEntry(patched.name.concat(".class"));
        if (entry == null) {
            throw new IllegalArgumentException(patched.name.concat(" not present in vanilla"));
        }
        ClassNode minecraftClass = ASMUtils.readClass(this.minecraftClientFile, entry);
        this.findLambdas(minecraftClass, patched);

        if (!this.fuzzes.isEmpty()) {
            Map<String, String> toCheck = new HashMap<>();
            Map<String, IMappingProvider.Member> checkedLambdas = new HashMap<>();
            for (Map.Entry<IMappingProvider.Member, Pair<String, String>> fuzz : this.fuzzes.entrySet()) {
                IMappingProvider.Member lambda = fuzz.getKey();
                //fuzzes accumulates across every class this instance processes, and keys below are only name + desc,
                //so lambdas from other classes (whose lambda$ names can coincide) must not be checked against this one
                if (!patched.name.equals(lambda.owner)) continue;

                Pair<String, String> remap = fuzz.getValue();
                String key = lambda.name.concat(lambda.desc);
                toCheck.put(key, remap.getLeft().concat(remap.getRight()));
                checkedLambdas.put(key, lambda);
            }
            if (!toCheck.isEmpty()) {
                this.fix(toCheck, checkedLambdas, minecraftClass, patched);
            }
        }
    }

    /**
     * Pairs the lambdas of two versions of the same class.
     *
     * @param original the vanilla class
     * @param patched the OptiFine class
     * @return the number of candidate OptiFine lambdas left unpaired (see the private overload)
     * @throws IllegalArgumentException if the two classes have different names
     */
    protected int findLambdas(ClassNode original, ClassNode patched) {
        if (!original.name.equals(patched.name)) {
            throw new IllegalArgumentException("Patched class (" + patched.name + ") is not the same as the original (" + original.name + ')');
        }
        return this.findLambdas(original.name, original.methods, patched.methods);
    }

    /**
     * Pairs methods that only exist in the vanilla version (lost) with synthetic {@code lambda$} methods that
     * only exist in the OptiFine version (gained). It does so by comparing the lambdas used by the methods both
     * versions share.
     *
     * @param className internal name of the class both method lists belong to
     * @param original the vanilla methods
     * @param patched the OptiFine methods
     * @return {@code 0} if there was nothing to pair, or if one shared method had every lambda paired by shape;
     *         otherwise the number of gained lambda candidates left unpaired
     */
    private int findLambdas(String className, List<MethodNode> original, List<MethodNode> patched) {
        Collector<MethodNode, ?, Map<String, MethodNode>> byNameAndDesc = Collectors.toMap((MethodNode method) -> method.name.concat(method.desc), Function.identity());
        List<MethodComparison> commonMethods = new ArrayList<>();
        List<MethodNode> lostMethods = new ArrayList<>();
        List<MethodNode> gainedMethods = new ArrayList<>();

        Map<String, MethodNode> originalMethods = original.stream().collect(byNameAndDesc);
        Map<String, MethodNode> patchedMethods = patched.stream().collect(byNameAndDesc);
        for (String methodName : Sets.union(originalMethods.keySet(), patchedMethods.keySet())) {
            MethodNode originalMethod = originalMethods.get(methodName);
            MethodNode patchedMethod = patchedMethods.get(methodName);

            if (originalMethod != null) {
                if (patchedMethod != null) {
                    commonMethods.add(new MethodComparison(originalMethod, patchedMethod));
                } else {
                    lostMethods.add(originalMethod);
                }
            } else if (patchedMethod != null) {
                gainedMethods.add(patchedMethod);
            } else {
                throw new IllegalStateException("Unable to find " + methodName + " in either " + className + " versions");
            }
        }

        //Follow the patched method order, with the static initialiser first (or last for GLX; reason not visible here)
        commonMethods.sort(Comparator.comparingInt((MethodComparison method) -> {
            if (!"<clinit>".equals(method.node.name)) {
                return patched.indexOf(method.node);
            }
            return "com/mojang/blaze3d/platform/GLX".equals(className) ? patched.size() : -1;
        }));
        lostMethods.sort(Comparator.comparingInt(original::indexOf));
        gainedMethods.sort(Comparator.comparingInt(patched::indexOf));

        if (commonMethods.stream().noneMatch(method -> !method.equal && method.hasLambdas()) || lostMethods.isEmpty() || gainedMethods.isEmpty()) {
            return 0;
        }

        Map<String, MethodNode> possibleLambdas = gainedMethods.stream()
                .filter(method -> (method.access & Opcodes.ACC_SYNTHETIC) != 0 && method.name.startsWith("lambda$"))
                .collect(byNameAndDesc);
        if (possibleLambdas.isEmpty()) {
            return 0;
        }
        Map<String, MethodNode> nameToLosses = lostMethods.stream().collect(byNameAndDesc);

        //commonMethods grows as pairings are made (see addFix), hence the size is re-read every iteration
        for (int i = 0; i < commonMethods.size(); i++) {
            MethodComparison method = commonMethods.get(i);
            if (method.effectivelyEqual) {
                this.resolveCloseMethod(className, commonMethods, lostMethods, gainedMethods, method, nameToLosses, possibleLambdas);
            }
        }

        //Resolves the comparison addFix just appended, plus any further ones appended whilst doing so
        Runnable resolveNewest = () -> {
            for (int j = commonMethods.size() - 1; j < commonMethods.size(); j++) {
                MethodComparison added = commonMethods.get(j);
                if (added.effectivelyEqual) {
                    this.resolveCloseMethod(className, commonMethods, lostMethods, gainedMethods, added, nameToLosses, possibleLambdas);
                }
            }
        };
        Collector<Lambda, ?, Map<String, Map<String, List<Lambda>>>> byDescThenMethod = Collectors.groupingBy((Lambda lambda) -> lambda.desc, Collectors.groupingBy((Lambda lambda) -> lambda.method));

        for (int i = 0; i < commonMethods.size(); i++) {
            MethodComparison method = commonMethods.get(i);
            if (method.effectivelyEqual) continue;

            List<Lambda> originalLambdas = method.getOriginalLambads();
            List<Lambda> patchedLambdas = method.getPatchedLambads();
            if (lambdasLineUp(originalLambdas, patchedLambdas)) {
                this.pairUp(className, commonMethods, lostMethods, gainedMethods, originalLambdas, patchedLambdas, nameToLosses, possibleLambdas, resolveNewest);
                continue;
            }

            //Otherwise only pair lambdas within (desc, method) groups that are the same size on both sides
            Map<String, Map<String, List<Lambda>>> descToOriginalLambda = originalLambdas.stream().collect(byDescThenMethod);
            Map<String, Map<String, List<Lambda>>> descToPatchedLambda = patchedLambdas.stream().collect(byDescThenMethod);
            Sets.SetView<String> commonDescs = Sets.intersection(descToOriginalLambda.keySet(), descToPatchedLambda.keySet());
            if (commonDescs.isEmpty()) continue;

            int fixedLambdas = 0;
            for (String desc : commonDescs) {
                Map<String, List<Lambda>> typeToOriginalLambda = descToOriginalLambda.get(desc);
                Map<String, List<Lambda>> typeToPatchedLambda = descToPatchedLambda.get(desc);

                for (String type : Sets.intersection(typeToOriginalLambda.keySet(), typeToPatchedLambda.keySet())) {
                    List<Lambda> matchedOriginalLambdas = typeToOriginalLambda.get(type);
                    List<Lambda> matchedPatchedLambdas = typeToPatchedLambda.get(type);
                    if (matchedOriginalLambdas.size() != matchedPatchedLambdas.size()) continue;

                    fixedLambdas += matchedOriginalLambdas.size();
                    this.pairUp(className, commonMethods, lostMethods, gainedMethods, matchedOriginalLambdas, matchedPatchedLambdas, nameToLosses, possibleLambdas, resolveNewest);
                }
            }

            if (fixedLambdas == originalLambdas.size()) {
                //NOTE(review): this skips the remaining common methods and reports nothing unsolved as soon as one
                //method's lambdas are all paired by shape (behaviour kept from the original) - confirm it is intended
                return 0;
            }
        }

        return possibleLambdas.size();
    }

    /**
     * Checks whether two lambda lists have the same length and every pair at the same position has the same shape.
     */
    private static boolean lambdasLineUp(List<Lambda> originalLambdas, List<Lambda> patchedLambdas) {
        if (originalLambdas.size() != patchedLambdas.size()) {
            return false;
        }

        Iterator<Lambda> itOriginal = originalLambdas.iterator();
        Iterator<Lambda> itPatched = patchedLambdas.iterator();
        while (itOriginal.hasNext() && itPatched.hasNext()) {
            if (!sameLambdaShape(itOriginal.next(), itPatched.next())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Two lambdas have the same shape when their {@code method} strings are equal. They also match when the
     * strings share the text before {@code '('}, have the same return type and the lambdas have equal owners.
     */
    private static boolean sameLambdaShape(Lambda original, Lambda patched) {
        if (Objects.equals(original.method, patched.method)) {
            return true;
        }

        int originalSplit = original.method.indexOf('(');
        int patchedSplit = patched.method.indexOf('(');
        return originalSplit == patchedSplit
                && original.method.regionMatches(0, patched.method, 0, originalSplit)
                && Type.getReturnType(original.method).equals(Type.getReturnType(patched.method))
                && Objects.equals(original.owner, patched.owner);
    }

    /**
     * Pairs, in order, the lambdas of a comparison whose bytecode was judged effectively equal but not identical.
     *
     * @throws IllegalStateException if the two versions use a different number of lambdas
     */
    private void resolveCloseMethod(String className, List<MethodComparison> commonMethods, List<MethodNode> lostMethods, List<MethodNode> gainedMethods, MethodComparison method, Map<String, MethodNode> nameToLosses, Map<String, MethodNode> possibleLambdas) {
        assert method.effectivelyEqual;

        if (!method.equal) {
            if (method.getOriginalLambads().size() != method.getPatchedLambads().size()) {
                throw new IllegalStateException("Bytecode in " + className + '#' + method.node.name + method.node.desc + " appeared unchanged but lambda count changed?");
            }
            this.pairUp(className, commonMethods, lostMethods, gainedMethods, method.getOriginalLambads(), method.getPatchedLambads(), nameToLosses, possibleLambdas, Runnables.doNothing());
        }
    }

    /**
     * Walks two equally sized lambda lists in parallel. Each pair's implementation methods are taken out of the
     * lost and candidate maps and recorded via {@link #addFix}.
     *
     * <p>Stops at the first original lambda whose owner is not {@code className}. A pair where both
     * implementations are absent from the maps is skipped. {@code onPair} runs after every pairing that was
     * recorded.
     *
     * @throws IllegalStateException if exactly one side of a pair cannot be found
     */
    private void pairUp(String className, List<MethodComparison> commonMethods, List<MethodNode> lostMethods, List<MethodNode> gainedMethods, List<Lambda> originalLambdas, List<Lambda> patchedLambdas, Map<String, MethodNode> nameToLosses, Map<String, MethodNode> possibleLambdas, Runnable onPair) {
        assert originalLambdas.size() == patchedLambdas.size();

        Iterator<Lambda> itOriginal = originalLambdas.iterator();
        Iterator<Lambda> itPatched = patchedLambdas.iterator();
        while (itOriginal.hasNext() && itPatched.hasNext()) {
            Lambda lost = itOriginal.next();
            Lambda gained = itPatched.next();

            if (!className.equals(lost.owner)) {
                return;
            }
            assert className.equals(gained.owner);

            MethodNode lostMethod = nameToLosses.remove(lost.getName());
            MethodNode gainedMethod = possibleLambdas.remove(gained.getName());
            if (lostMethod == null) {
                if (gainedMethod == null) {
                    assert Objects.equals(lost.getFullName(), gained.getFullName());
                    continue;
                }
                throw new IllegalStateException("Couldn't find original method for lambda: " + lost.getFullName());
            }
            if (gainedMethod == null) {
                throw new IllegalStateException("Couldn't find patched method for lambda: " + gained.getFullName());
            }

            if (this.addFix(className, commonMethods, gainedMethod, lostMethod)) {
                lostMethods.remove(lostMethod);
                gainedMethods.remove(gainedMethod);
                onPair.run();
            }
        }
    }

    /**
     * Records that OptiFine's {@code from} corresponds to vanilla's {@code to}.
     *
     * <p>The pairing goes in {@code fuzzes} when the descriptors differ, otherwise in {@code fixes}. The two
     * methods are then appended to {@code commonMethods} as a comparison marked effectively equal.
     *
     * @return {@code false} if the descriptors differ and vague equivalence is disabled, else {@code true}
     */
    private boolean addFix(String className, List<MethodComparison> commonMethods, MethodNode from, MethodNode to) {
        boolean vague = !from.desc.equals(to.desc);
        if (vague && !ALLOW_VAGUE_EQUIVALENCE) {
            System.err.println("Description changed remapping lambda handle: " + className + '#' + from.name + from.desc + " => " + className + '#' + to.name + to.desc);
            return false;
        }

        if (vague) {
            System.out.printf("Fuzzing %s#%s%s as %s%s%n", className, from.name, from.desc, to.name, to.desc);
            this.fuzzes.put(new IMappingProvider.Member(className, from.name, from.desc), Pair.of(to.name, to.desc));
        } else {
            this.fixes.put(new IMappingProvider.Member(className, from.name, from.desc), this.remapName(className, to.name, to.desc));
        }
        commonMethods.add(new MethodComparison(to, from, true));
        return true;
    }

    /**
     * Maps a method name out of the {@code official} namespace using Fabric Loader's mapping resolver.
     *
     * @param owner internal (slash separated) name of the owning class
     */
    protected String remapName(String owner, String name, String desc) {
        return FabricLoader.getInstance().getMappingResolver().mapMethodName("official", owner.replace('/', '.'), name, desc);
    }

    /** Supplies every exact and fuzzy lambda pairing found so far to the remapper. */
    @Override
    public void load(IMappingProvider.MappingAcceptor out) {
        this.fixes.forEach((lambda, name) -> out.acceptMethod(lambda, name));

        for (Map.Entry<IMappingProvider.Member, Pair<String, String>> entry : this.fuzzes.entrySet()) {
            IMappingProvider.Member lambda = entry.getKey();
            Pair<String, String> remap = entry.getValue();
            out.acceptMethod(lambda, this.remapName(lambda.owner, remap.getLeft(), remap.getRight()));
        }
    }

    /**
     * Makes fuzzily paired private lambdas in {@code optifine} agree with vanilla on being static or not.
     *
     * @param toCheck OptiFine lambda name + desc => vanilla name + desc
     * @param checkedLambdas OptiFine lambda name + desc => its pairing key, whose desc is updated on a flip
     */
    private void fix(Map<String, String> toCheck, Map<String, IMappingProvider.Member> checkedLambdas, ClassNode minecraft, ClassNode optifine) {
        Map<String, String> staticFlip = findStaticFlips(toCheck, minecraft, optifine);

        if (!staticFlip.isEmpty()) {
            applyStaticFlips(staticFlip, checkedLambdas, optifine);
        }
    }

    /**
     * Finds the OptiFine lambdas whose static-ness differs from their vanilla pair.
     *
     * <p>Only private methods can be flipped. A static method must take the class itself as its first parameter,
     * which becomes the receiver. An instance method gains the receiver as an explicit first parameter.
     *
     * @return OptiFine method name + old desc => desc after flipping
     * @throws IllegalStateException if a paired vanilla method is missing
     * @throws UnsupportedOperationException if a method cannot be flipped
     */
    private static Map<String, String> findStaticFlips(Map<String, String> toCheck, ClassNode minecraft, ClassNode optifine) {
        Object2IntOpenHashMap<String> memberToAccess = new Object2IntOpenHashMap<>(minecraft.methods.size());
        memberToAccess.defaultReturnValue(-1);
        for (MethodNode method : minecraft.methods) {
            memberToAccess.put(method.name.concat(method.desc), method.access);
        }

        Map<String, String> staticFlip = new HashMap<>();
        for (MethodNode method : optifine.methods) {
            String key = method.name.concat(method.desc);
            String remap = toCheck.get(key);
            if (remap == null) continue;

            int access = memberToAccess.getInt(remap);
            if (access == -1) {
                throw new IllegalStateException("Unable to find vanilla method " + minecraft.name + '#' + remap);
            }

            boolean shouldBeStatic = Modifier.isStatic(access);
            if (Modifier.isStatic(method.access) == shouldBeStatic) continue;

            if (shouldBeStatic) {
                if (!Modifier.isPrivate(method.access)) {
                    throw new UnsupportedOperationException("Method is no longer static: " + optifine.name + '#' + key);
                }
                staticFlip.put(key, "(L" + optifine.name + ';' + method.desc.substring(1));
            } else {
                Type[] args = Type.getArgumentTypes(method.desc);
                if (!Modifier.isPrivate(method.access) || args.length == 0 || !optifine.name.equals(args[0].getInternalName())) {
                    throw new UnsupportedOperationException("Method has become static: " + optifine.name + '#' + key);
                }
                staticFlip.put(key, Type.getMethodDescriptor(Type.getReturnType(method.desc), Arrays.copyOfRange(args, 1, args.length)));
            }
        }
        return staticFlip;
    }

    /**
     * Toggles {@code ACC_STATIC} on every flipped method and gives it the new descriptor. It then rewrites calls
     * to those methods, and LambdaMetafactory implementation handles, within the same class to match.
     *
     * @throws NullPointerException if a flipped method has no entry in {@code checkedLambdas}
     */
    private static void applyStaticFlips(Map<String, String> staticFlip, Map<String, IMappingProvider.Member> checkedLambdas, ClassNode optifine) {
        for (MethodNode method : optifine.methods) {
            String key = method.name.concat(method.desc);
            String newDesc = staticFlip.get(key);
            if (newDesc != null) {
                method.access ^= Opcodes.ACC_STATIC;
                Objects.requireNonNull(checkedLambdas.get(key), () -> "Failed to find lambda " + optifine.name + '#' + key).desc = newDesc;
                method.desc = newDesc;
            }

            for (AbstractInsnNode insn : method.instructions) {
                switch (insn.getType()) {
                    case AbstractInsnNode.METHOD_INSN: {
                        MethodInsnNode minsn = (MethodInsnNode) insn;
                        if (!optifine.name.equals(minsn.owner)) break;

                        String flippedDesc = staticFlip.get(minsn.name.concat(minsn.desc));
                        if (flippedDesc == null) break;

                        minsn.setOpcode(minsn.getOpcode() == Opcodes.INVOKESTATIC ? Opcodes.INVOKEVIRTUAL : Opcodes.INVOKESTATIC);
                        minsn.desc = flippedDesc;
                        break;
                    }

                    case AbstractInsnNode.INVOKE_DYNAMIC_INSN: {
                        InvokeDynamicInsnNode dinsn = (InvokeDynamicInsnNode) insn;
                        if (!MethodComparison.isJavaLambdaMetafactory(dinsn.bsm)) break;

                        Handle implementation = (Handle) dinsn.bsmArgs[1];
                        if (!optifine.name.equals(implementation.getOwner())) break;

                        String flippedDesc = staticFlip.get(implementation.getName().concat(implementation.getDesc()));
                        if (flippedDesc == null) break;

                        int tag = implementation.getTag() == Opcodes.H_INVOKESTATIC ? Opcodes.H_INVOKEVIRTUAL : Opcodes.H_INVOKESTATIC;
                        dinsn.bsmArgs[1] = new Handle(tag, implementation.getOwner(), implementation.getName(), flippedDesc, implementation.isInterface());
                        break;
                    }
                }
            }
        }
    }

    /** Closes the vanilla jar, if this instance opened one. */
    @Override
    public void close() throws IOException {
        if (this.minecraftClientFile != null) {
            this.minecraftClientFile.close();
        }
    }
}

