package amie.data.eval;

import amie.data.FactDatabase;
import amie.query.AMIEreader;
import amie.query.Query;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javatools.datatypes.ByteString;
import javatools.datatypes.IntHashMap;
import javatools.datatypes.Triple;

/* loaded from: input_file:amie/data/eval/PCAFalseFactsSampler.class */
/**
 * Samples, for each AMIE rule, facts that the rule predicts but that are assumed false with
 * respect to the knowledge base, and prints them to standard output.
 *
 * <p>For a rule whose head atom contains two variables, one head argument is replaced by a
 * fresh variable. Which argument is replaced depends on
 * {@link Query#getFunctionalVariablePosition()}. {@link FactDatabase#difference} is then asked
 * for the head bindings of that relaxed query minus those of the full rule. Judging by the
 * class name, these are the predictions counted as counter-examples under the partial
 * completeness assumption (PCA); confirm against {@code FactDatabase.difference}.
 *
 * <p>Every sampled fact is printed as one tab-separated line:
 * {@code rule <TAB> subject <TAB> relation <TAB> object}.
 */
public class PCAFalseFactsSampler {
    /** Base name of the placeholder variable substituted into the rule head. */
    private static final String PLACEHOLDER_VARIABLE = "?x";

    /** Knowledge base used to compute the assumed-false facts. */
    private final FactDatabase db;

    /** Sample size handed to {@code PredictionsSampler.sample} per rule (or per relation). */
    private final int sampleSize;

    /**
     * @param db         knowledge base used to compute the assumed-false facts
     * @param sampleSize sample size requested from {@code PredictionsSampler.sample}
     */
    public PCAFalseFactsSampler(FactDatabase db, int sampleSize) {
        this.db = db;
        this.sampleSize = sampleSize;
    }

    /**
     * Command-line entry point: {@code <db> <samplesPerRule> <rules>...}.
     *
     * <p>Loads the knowledge base and every rules file. For each rule it sets the projection
     * variable to the head subject when the head relation's functionality is at least its
     * inverse functionality, and to the head object otherwise. It then prints a sample of
     * assumed-false facts per rule.
     *
     * @param args command-line arguments; at least three are required
     * @throws IOException if the knowledge base or a rules file cannot be read
     * @throws NumberFormatException if {@code samplesPerRule} is not an integer
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 3) {
            System.err.println("PCAFalseFactsSampler <db> <samplesPerRule> <rules>");
            System.err.println("db\tAn RDF knowledge base");
            System.err.println("samplesPerRule\tSample size per rule. It defines the number of facts that will be randomly taken from the entire set of predictions made by each rule");
            System.err.println("rules\tFile containing each rule per line, as they are output by AMIE.");
            System.exit(1);
        }
        // Parse the numeric argument before loading the (potentially large) KB so that a
        // malformed value fails immediately instead of after the load.
        int samplesPerRule = Integer.parseInt(args[1]);

        FactDatabase db = new FactDatabase();
        db.load(new File(args[0]));

        List<Query> rules = new ArrayList<Query>();
        for (int i = 2; i < args.length; i++) {
            rules.addAll(AMIEreader.rules(new File(args[i])));
        }

        for (Query rule : rules) {
            ByteString[] head = rule.getHead();
            ByteString relation = head[1];
            if (db.functionality(relation) >= db.inverseFunctionality(relation)) {
                rule.setProjectionVariable(head[0]);
            } else {
                rule.setProjectionVariable(head[2]);
            }
        }

        new PCAFalseFactsSampler(db, samplesPerRule).run(rules);
    }

    /**
     * Prints, for each rule independently, a sample of its assumed-false facts.
     *
     * @param rules rules to process, in iteration order
     */
    private void run(Collection<Query> rules) {
        for (Query rule : rules) {
            for (Triple<ByteString, ByteString, ByteString> fact
                    : PredictionsSampler.sample(generateAssumedFalseFacts(rule), sampleSize)) {
                System.out.println(formatLine(rule, fact));
            }
        }
    }

    /**
     * Variant of {@link #run} that pools the assumed-false facts of all rules sharing the same
     * head relation and draws one sample per relation. It is not called from within this class.
     *
     * <p>If several rules of the same relation generate the same fact, the fact is attributed
     * to the last of those rules processed, because its entry in the attribution map is
     * overwritten.
     *
     * @param rules rules to process
     */
    private void runAndGroupByRelation(Collection<Query> rules) {
        Map<ByteString, Collection<Query>> rulesByRelation =
                new HashMap<ByteString, Collection<Query>>();
        for (Query rule : rules) {
            ByteString relation = rule.getHead()[1];
            Collection<Query> group = rulesByRelation.get(relation);
            if (group == null) {
                group = new ArrayList<Query>();
                rulesByRelation.put(relation, group);
            }
            group.add(rule);
        }

        for (Map.Entry<ByteString, Collection<Query>> entry : rulesByRelation.entrySet()) {
            Map<Triple<ByteString, ByteString, ByteString>, Query> sourceRule =
                    new HashMap<Triple<ByteString, ByteString, ByteString>, Query>();
            Set<Triple<ByteString, ByteString, ByteString>> pooledFacts =
                    new LinkedHashSet<Triple<ByteString, ByteString, ByteString>>();
            for (Query rule : entry.getValue()) {
                Set<Triple<ByteString, ByteString, ByteString>> facts = generateAssumedFalseFacts(rule);
                for (Triple<ByteString, ByteString, ByteString> fact : facts) {
                    sourceRule.put(fact, rule);
                }
                pooledFacts.addAll(facts);
            }
            printSample(PredictionsSampler.sample(pooledFacts, sampleSize), sourceRule);
        }
    }

    /**
     * Prints each sampled fact together with the rule it is attributed to.
     *
     * @param sample     facts to print
     * @param sourceRule attribution map; must contain every fact in {@code sample}
     */
    private void printSample(Collection<Triple<ByteString, ByteString, ByteString>> sample,
                             Map<Triple<ByteString, ByteString, ByteString>, Query> sourceRule) {
        for (Triple<ByteString, ByteString, ByteString> fact : sample) {
            System.out.println(formatLine(sourceRule.get(fact), fact));
        }
    }

    /** Builds the tab-separated output line {@code rule, subject, relation, object}. */
    private static String formatLine(Query rule, Triple<ByteString, ByteString, ByteString> fact) {
        return String.valueOf(rule.getRuleString()) + "\t" + fact.first + "\t" + fact.second
                + "\t" + fact.third;
    }

    /**
     * Computes the assumed-false facts of a rule.
     *
     * <p>The relaxed query consists of the rule head, with one argument replaced by a fresh
     * variable, followed by copies of the body atoms. When the functional variable position is
     * 0, the object (index 2) is replaced; otherwise the subject (index 0) is replaced. The
     * result is built from {@code db.difference(head[0], head[2], relaxedQuery, rule triples)}.
     * Each returned binding becomes a triple {@code (subject, head relation, object)}.
     *
     * @param query the rule
     * @return the facts in insertion order; empty if the head does not contain exactly two
     *         variables
     */
    private Set<Triple<ByteString, ByteString, ByteString>> generateAssumedFalseFacts(Query query) {
        Set<Triple<ByteString, ByteString, ByteString>> falseFacts =
                new LinkedHashSet<Triple<ByteString, ByteString, ByteString>>();
        ByteString[] head = query.getHead();
        if (FactDatabase.numVariables(head) != 2) {
            return falseFacts;
        }
        ByteString relation = head[1];

        // Copy the body atoms so the relaxed query does not share arrays with the rule.
        List<ByteString[]> body = new ArrayList<ByteString[]>();
        Iterator<ByteString[]> bodyIterator = query.getAntecedent().iterator();
        while (bodyIterator.hasNext()) {
            body.add(bodyIterator.next().clone());
        }

        // The placeholder must not already occur in the rule; otherwise the relaxed head would
        // accidentally join with an existing body variable.
        ByteString placeholder = freshVariable(head, body);
        ByteString[] relaxedHead = head.clone();
        if (query.getFunctionalVariablePosition() == 0) {
            relaxedHead[2] = placeholder;
        } else {
            relaxedHead[0] = placeholder;
        }

        List<ByteString[]> relaxedQuery = new ArrayList<ByteString[]>(body.size() + 1);
        relaxedQuery.add(relaxedHead);
        relaxedQuery.addAll(body);

        Map<ByteString, IntHashMap<ByteString>> bindings =
                db.difference(head[0], head[2], relaxedQuery, query.getTriples());
        for (Map.Entry<ByteString, IntHashMap<ByteString>> entry : bindings.entrySet()) {
            Iterator<ByteString> objects = entry.getValue().iterator();
            while (objects.hasNext()) {
                falseFacts.add(new Triple<ByteString, ByteString, ByteString>(
                        entry.getKey(), relation, objects.next()));
            }
        }
        return falseFacts;
    }

    /**
     * Returns {@code ?x} if it does not occur in the rule. Otherwise it returns the first of
     * {@code ?x1, ?x2, ...} that does not occur.
     */
    private static ByteString freshVariable(ByteString[] head, List<ByteString[]> body) {
        String name = PLACEHOLDER_VARIABLE;
        int suffix = 0;
        while (occursIn(ByteString.of(name), head, body)) {
            suffix++;
            name = PLACEHOLDER_VARIABLE + suffix;
        }
        return ByteString.of(name);
    }

    /** Returns whether {@code term} appears in the head atom or in any body atom. */
    private static boolean occursIn(ByteString term, ByteString[] head, List<ByteString[]> body) {
        for (ByteString t : head) {
            if (term.equals(t)) {
                return true;
            }
        }
        for (ByteString[] atom : body) {
            for (ByteString t : atom) {
                if (term.equals(t)) {
                    return true;
                }
            }
        }
        return false;
    }
}
