package amie.data.eval;

import amie.data.FactDatabase;
import amie.query.AMIEreader;
import amie.query.Query;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import javatools.datatypes.ByteString;
import javatools.datatypes.IntHashMap;
import javatools.datatypes.Triple;

/* loaded from: input_file:amie/data/eval/PredictionsSampler.class */
/**
 * Samples and prints the predictions made by AMIE rules.
 *
 * <p>For each rule, the predictions are computed with {@code FactDatabase.difference} over the
 * rule's antecedent and its full triple list, turned into head-shaped triples, uniformly sampled
 * down to {@link #getSampleSize()} elements, and printed one per line as
 * {@code rule \t subject \t relation \t object}.
 */
public class PredictionsSampler {
    /** Sample size used by the single-argument constructor. */
    private static final int DEFAULT_SAMPLE_SIZE = 30;

    /** Maximum number of predictions printed per rule. */
    private int sampleSize;

    /** Knowledge base the rules are evaluated against. */
    private final FactDatabase source;

    /**
     * Creates a sampler with the default sample size (30).
     *
     * @param factDatabase knowledge base used to compute bindings and predictions
     */
    public PredictionsSampler(FactDatabase factDatabase) {
        this(factDatabase, DEFAULT_SAMPLE_SIZE);
    }

    /**
     * Creates a sampler with an explicit sample size.
     *
     * @param factDatabase knowledge base used to compute bindings and predictions
     * @param i maximum number of predictions sampled per rule
     */
    public PredictionsSampler(FactDatabase factDatabase, int i) {
        this.source = factDatabase;
        this.sampleSize = i;
    }

    /** @return the maximum number of predictions sampled per rule */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Sets the maximum number of predictions sampled per rule.
     *
     * @param i new sample size
     */
    public void setNumberOfPredictions(int i) {
        this.sampleSize = i;
    }

    /**
     * Computes the distinct bindings of the rule body (antecedent).
     *
     * @param query the rule
     * @return for single-variable heads, the result of
     *     {@code selectDistinct(functionalVariable, antecedent)}; otherwise the result of the
     *     two-variable {@code selectDistinct} over the projection pair
     */
    public Object generateBodyBindings(Query query) {
        return FactDatabase.numVariables(query.getHead()) == 1
                ? generateBindingsForSingleVariable(query)
                : generateBindingsForTwoVariables(query);
    }

    /**
     * Computes the predictions of a rule via {@code FactDatabase.difference}.
     *
     * @param query the rule
     * @return a {@code Set<ByteString>} for single-variable heads, otherwise a
     *     {@code Map<ByteString, IntHashMap<ByteString>>} keyed by the functional variable
     */
    public Object generatePredictions(Query query) {
        return FactDatabase.numVariables(query.getHead()) == 1
                ? predictBindingsForSingleVariable(query)
                : predictBindingsForTwoVariables(query);
    }

    /**
     * Prints sampled predictions for every rule, skipping predictions already produced by an
     * earlier rule in {@code collection}.
     *
     * @param collection rules to evaluate, in order
     */
    public void runMode1(Collection<Query> collection) {
        // Shared across rules so that later rules do not re-emit earlier predictions.
        Map<ByteString, Map<ByteString, Set<ByteString>>> seen = new HashMap<>();
        for (Query query : collection) {
            printPredictions(query, samplePredictions(generatePredictions(query), query, seen));
        }
    }

    /**
     * Prints sampled predictions for every rule independently (no cross-rule deduplication).
     *
     * @param collection rules to evaluate, in order
     */
    public void runMode2(Collection<Query> collection) {
        for (Query query : collection) {
            printPredictions(query, samplePredictions(generatePredictions(query), query));
        }
    }

    /**
     * Returns the two variables used to project a two-variable rule: the functional variable
     * first, then the other head variable.
     */
    private static ByteString[] twoVariableProjection(Query query) {
        ByteString[] head = query.getHead();
        ByteString functional = query.getFunctionalVariable();
        ByteString other = head[FactDatabase.secondVariablePos(head)];
        // If the functional variable is the second head variable, pair it with the first one.
        if (other.equals(functional)) {
            other = head[FactDatabase.firstVariablePos(head)];
        }
        return new ByteString[] {functional, other};
    }

    private Object generateBindingsForTwoVariables(Query query) {
        ByteString[] vars = twoVariableProjection(query);
        return this.source.selectDistinct(vars[0], vars[1], query.getAntecedent());
    }

    private Object generateBindingsForSingleVariable(Query query) {
        return this.source.selectDistinct(query.getFunctionalVariable(), query.getAntecedent());
    }

    private Map<ByteString, IntHashMap<ByteString>> predictBindingsForTwoVariables(Query query) {
        ByteString[] vars = twoVariableProjection(query);
        return this.source.difference(vars[0], vars[1], query.getAntecedent(), query.getTriples());
    }

    private Set<ByteString> predictBindingsForSingleVariable(Query query) {
        return this.source.difference(
                query.getFunctionalVariable(), query.getAntecedent(), query.getTriples());
    }

    /** Samples predictions without cross-rule deduplication. Never returns {@code null}. */
    private Collection<Triple<ByteString, ByteString, ByteString>> samplePredictions(
            Object obj, Query query) {
        return samplePredictions(obj, query, null);
    }

    /**
     * Converts raw predictions into triples and samples them.
     *
     * @param obj output of {@link #generatePredictions(Query)} for {@code query}
     * @param query the rule
     * @param map predictions already emitted (relation -> subject -> objects), updated in place;
     *     {@code null} disables deduplication
     * @return sampled triples; empty (never {@code null}) for unsupported head shapes
     */
    @SuppressWarnings("unchecked") // obj's runtime type is fixed by generatePredictions.
    private Collection<Triple<ByteString, ByteString, ByteString>> samplePredictions(
            Object obj, Query query, Map<ByteString, Map<ByteString, Set<ByteString>>> map) {
        int numVariables = FactDatabase.numVariables(query.getHead());
        if (numVariables == 2) {
            return samplePredictionsTwoVariables(
                    (Map<ByteString, IntHashMap<ByteString>>) obj, query, map);
        }
        if (numVariables == 1) {
            return samplePredictionsOneVariable((Set<ByteString>) obj, query, map);
        }
        return new ArrayList<Triple<ByteString, ByteString, ByteString>>();
    }

    /**
     * Single-variable heads are not supported yet.
     *
     * <p>TODO(review): the original returned {@code null} here, which made
     * {@link #printPredictions} throw; an empty collection keeps the run going.
     */
    private Collection<Triple<ByteString, ByteString, ByteString>> samplePredictionsOneVariable(
            Set<ByteString> set, Query query, Map<ByteString, Map<ByteString, Set<ByteString>>> map) {
        return new ArrayList<Triple<ByteString, ByteString, ByteString>>();
    }

    /**
     * Builds head triples from two-variable predictions and samples them.
     *
     * @param map functional-variable binding -> bindings of the other head variable
     * @param query the rule; its head relation ({@code getHead()[1]}) becomes the triple relation
     * @param seen predictions already emitted, or {@code null} to disable deduplication
     * @return up to {@link #getSampleSize()} sampled triples
     */
    private Collection<Triple<ByteString, ByteString, ByteString>> samplePredictionsTwoVariables(
            Map<ByteString, IntHashMap<ByteString>> map,
            Query query,
            Map<ByteString, Map<ByteString, Set<ByteString>>> seen) {
        ByteString relation = query.getHead()[1];
        boolean functionalIsSubject = query.getFunctionalVariablePosition() == 0;
        Set<Triple<ByteString, ByteString, ByteString>> predictions = new LinkedHashSet<>();
        for (ByteString key : map.keySet()) {
            Iterator<ByteString> it = map.get(key).iterator();
            while (it.hasNext()) {
                ByteString value = it.next();
                // Reflexive predictions (x, r, x) are skipped.
                if (key.equals(value)) {
                    continue;
                }
                ByteString subject = functionalIsSubject ? key : value;
                ByteString object = functionalIsSubject ? value : key;
                Triple<ByteString, ByteString, ByteString> triple =
                        new Triple<ByteString, ByteString, ByteString>(subject, relation, object);
                if (seen == null) {
                    predictions.add(triple);
                } else {
                    if (!containsPrediction(seen, triple)) {
                        predictions.add(triple);
                    }
                    addPrediction(seen, triple);
                }
            }
        }
        return sample(predictions, this.sampleSize);
    }

    /** Prints each triple as {@code rule \t first \t second \t third}. */
    private void printPredictions(
            Query query, Collection<Triple<ByteString, ByteString, ByteString>> collection) {
        for (Triple<ByteString, ByteString, ByteString> triple : collection) {
            System.out.println(String.valueOf(query.getRuleString()) + "\t" + triple.first
                    + "\t" + triple.second + "\t" + triple.third);
        }
    }

    /** Records {@code triple} in the relation -> subject -> objects index. */
    private void addPrediction(
            Map<ByteString, Map<ByteString, Set<ByteString>>> map,
            Triple<ByteString, ByteString, ByteString> triple) {
        Map<ByteString, Set<ByteString>> bySubject = map.get(triple.second);
        if (bySubject == null) {
            bySubject = new HashMap<>();
            map.put(triple.second, bySubject);
        }
        Set<ByteString> objects = bySubject.get(triple.first);
        if (objects == null) {
            objects = new HashSet<>();
            bySubject.put(triple.first, objects);
        }
        objects.add(triple.third);
    }

    /** @return whether {@code triple} is present in the relation -> subject -> objects index */
    private boolean containsPrediction(
            Map<ByteString, Map<ByteString, Set<ByteString>>> map,
            Triple<ByteString, ByteString, ByteString> triple) {
        Map<ByteString, Set<ByteString>> bySubject = map.get(triple.second);
        if (bySubject == null) {
            return false;
        }
        Set<ByteString> objects = bySubject.get(triple.first);
        return objects != null && objects.contains(triple.third);
    }

    /**
     * Draws a uniform random sample of at most {@code i} elements (reservoir sampling,
     * Algorithm R).
     *
     * @param collection elements to sample from
     * @param i sample size; must be non-negative
     * @return {@code collection} itself when it has at most {@code i} elements, otherwise a new
     *     list of exactly {@code i} elements
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public static Collection<Triple<ByteString, ByteString, ByteString>> sample(
            Collection<Triple<ByteString, ByteString, ByteString>> collection, int i) {
        if (i < 0) {
            throw new IllegalArgumentException("Sample size must be non-negative: " + i);
        }
        if (collection.size() <= i) {
            return collection;
        }
        ArrayList<Triple<ByteString, ByteString, ByteString>> reservoir = new ArrayList<>(i);
        Random random = new Random();
        int seen = 0;
        for (Triple<ByteString, ByteString, ByteString> triple : collection) {
            if (seen < i) {
                reservoir.add(triple);
            } else {
                // Element #seen (0-based) is kept with probability i / (seen + 1) and, if kept,
                // replaces a uniformly chosen slot.
                int j = random.nextInt(seen + 1);
                if (j < i) {
                    reservoir.set(j, triple);
                }
            }
            seen++;
        }
        return reservoir;
    }

    /**
     * Command-line entry point: {@code PredictionsSampler <db> <samplesPerRule> <unique> <rules>...}.
     *
     * @param strArr command-line arguments
     * @throws IOException if the knowledge base or a rules file cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr.length < 4) {
            System.err.println("PredictionsSampler <db> <samplesPerRule> <unique> <rules>");
            System.err.println("db\tAn RDF knowledge base");
            System.err.println("samplesPerRule\tSample size per rule. It defines the number of facts that will be randomly taken from the entire set of predictions made a each rule");
            System.err.println("unique (0|1)\tIf 1, predictions that were generated by other rules before, are not output");
            System.err.println("rules\tFile containing each rule per line, as they are output by AMIE.");
            System.exit(1);
        }
        FactDatabase factDatabase = new FactDatabase();
        factDatabase.load(new File(strArr[0]));
        int samplesPerRule = Integer.parseInt(strArr[1]);
        int unique = Integer.parseInt(strArr[2]);
        ArrayList<Query> rules = new ArrayList<>();
        for (int i = 3; i < strArr.length; i++) {
            rules.addAll(AMIEreader.rules(new File(strArr[i])));
        }
        // Project on the subject when the head relation is at least as functional as it is
        // inverse-functional; otherwise project on the object.
        for (Query query : rules) {
            ByteString relation = query.getHead()[1];
            if (factDatabase.functionality(relation) >= factDatabase.inverseFunctionality(relation)) {
                query.setProjectionVariable(query.getHead()[0]);
            } else {
                query.setProjectionVariable(query.getHead()[2]);
            }
        }
        PredictionsSampler predictionsSampler = new PredictionsSampler(factDatabase, samplesPerRule);
        if (unique == 1) {
            predictionsSampler.runMode1(rules);
        } else {
            predictionsSampler.runMode2(rules);
        }
    }
}
