package ai.grakn.graql.internal.parser;

import ai.grakn.concept.AttributeType;
import ai.grakn.exception.GraqlQueryException;
import ai.grakn.exception.GraqlSyntaxException;
import ai.grakn.graql.Aggregate;
import ai.grakn.graql.Graql;
import ai.grakn.graql.Pattern;
import ai.grakn.graql.Query;
import ai.grakn.graql.QueryBuilder;
import ai.grakn.graql.Var;
import ai.grakn.graql.internal.antlr.GraqlLexer;
import ai.grakn.graql.internal.antlr.GraqlParser;
import ai.grakn.graql.internal.query.aggregate.Aggregates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.UnbufferedTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

/* loaded from: input_file:ai/grakn/graql/internal/parser/QueryParser.class */
/**
 * Parses Graql text into queries and patterns.
 *
 * <p>Every parse builds a fresh {@link GraqlLexer} and {@link GraqlParser}. It replaces their
 * default error listeners with a {@link GraqlErrorListener} and throws a
 * {@link GraqlSyntaxException} if that listener recorded any error. Otherwise it converts the parse
 * tree with a new {@link QueryVisitor}. That visitor is built from this parser's
 * {@link QueryBuilder} and an immutable snapshot of the registered aggregates.
 *
 * <p>Instances are obtained through {@link #create(QueryBuilder)}, which pre-registers the aggregates
 * {@code count}, {@code ask}, {@code sum}, {@code max}, {@code min}, {@code mean}, {@code median},
 * {@code std} and {@code group}.
 */
public class QueryParser {

    /**
     * Bidirectional mapping between the lowercase names {@code "long"}, {@code "double"},
     * {@code "string"}, {@code "boolean"} and {@code "date"} and the corresponding
     * {@link AttributeType.DataType} values. These are presumably the spellings used in Graql text;
     * confirm against the grammar.
     */
    public static final ImmutableBiMap<String, AttributeType.DataType> DATA_TYPES = ImmutableBiMap.of(
            "long", AttributeType.DataType.LONG,
            "double", AttributeType.DataType.DOUBLE,
            "string", AttributeType.DataType.STRING,
            "boolean", AttributeType.DataType.BOOLEAN,
            "date", AttributeType.DataType.DATE);

    private final QueryBuilder queryBuilder;

    /** Aggregate factories keyed by name; snapshotted into each {@link QueryVisitor}. */
    private final Map<String, Function<List<Object>, Aggregate>> aggregateMethods = new HashMap<>();

    private QueryParser(QueryBuilder queryBuilder) {
        this.queryBuilder = queryBuilder;
    }

    /**
     * Creates a parser bound to {@code queryBuilder} with the default aggregates registered.
     *
     * @param queryBuilder builder handed to every {@link QueryVisitor} this parser creates
     * @return a new parser
     */
    public static QueryParser create(QueryBuilder queryBuilder) {
        QueryParser queryParser = new QueryParser(queryBuilder);
        queryParser.registerDefaultAggregates();
        return queryParser;
    }

    /** Registers an aggregate that must be called with exactly {@code argCount} arguments. */
    private void registerAggregate(String name, int argCount, Function<List<Object>, Aggregate> factory) {
        registerAggregate(name, argCount, argCount, factory);
    }

    /**
     * Registers an aggregate whose argument count must lie in {@code [minArgs, maxArgs]}, inclusive.
     * The registered function throws
     * {@link GraqlQueryException#incorrectAggregateArgumentNumber} before calling {@code factory}
     * when the count is out of range.
     */
    private void registerAggregate(
            String name, int minArgs, int maxArgs, Function<List<Object>, Aggregate> factory) {
        registerAggregate(name, args -> {
            if (args.size() < minArgs || args.size() > maxArgs) {
                throw GraqlQueryException.incorrectAggregateArgumentNumber(name, minArgs, maxArgs, args);
            }
            return factory.apply(args);
        });
    }

    /**
     * Registers (or replaces) an aggregate under {@code name}. No argument-count check is added; the
     * function receives the raw argument list.
     *
     * @param name            aggregate name; must not be {@code null}
     * @param aggregateMethod factory building the aggregate from its arguments; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}. This is checked eagerly:
     *     a null entry would otherwise make every later parse fail inside {@code ImmutableMap.copyOf}.
     */
    public void registerAggregate(String name, Function<List<Object>, Aggregate> aggregateMethod) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(aggregateMethod, "aggregateMethod");
        aggregateMethods.put(name, aggregateMethod);
    }

    /**
     * Parses a single query using the {@code queryEOF} grammar rule (presumably a query followed by
     * end of input).
     *
     * <p>The cast to {@code T} is unchecked: the caller is responsible for requesting the query type
     * the text actually produces.
     *
     * @param str Graql text
     * @return the parsed query
     * @throws GraqlSyntaxException if the lexer or parser reported a syntax error
     */
    @SuppressWarnings("unchecked")
    public <T extends Query<?>> T parseQuery(String str) {
        return (T) parseQueryFragment(GraqlParser::queryEOF, QueryVisitor::visitQueryEOF, str, getLexer(str));
    }

    /**
     * Parses a list of queries using the {@code queryList} grammar rule.
     *
     * <p>Unlike the other entry points, tokens come from an {@link UnbufferedTokenStream} over a
     * {@code ChannelTokenSource} instead of a {@link CommonTokenStream}. This is presumably so that
     * large inputs need not be fully buffered; confirm intent. The element cast to {@code T} is
     * unchecked.
     *
     * @param str Graql text containing zero or more queries
     * @return the parsed queries
     * @throws GraqlSyntaxException if the lexer or parser reported a syntax error during the parse
     */
    @SuppressWarnings("unchecked")
    public <T extends Query<?>> Stream<T> parseList(String str) {
        GraqlLexer lexer = getLexer(str);
        GraqlErrorListener errorListener = attachErrorListener(lexer, str);
        TokenStream tokens = new UnbufferedTokenStream<>(ChannelTokenSource.of(lexer));
        Stream<?> queries =
                parseQueryFragment(GraqlParser::queryList, QueryVisitor::visitQueryList, tokens, errorListener);
        return (Stream<T>) queries;
    }

    /**
     * Parses a sequence of patterns using the {@code patterns} grammar rule.
     *
     * @throws GraqlSyntaxException if the lexer or parser reported a syntax error
     */
    public List<Pattern> parsePatterns(String str) {
        return parseQueryFragment(GraqlParser::patterns, QueryVisitor::visitPatterns, str, getLexer(str));
    }

    /**
     * Parses a single pattern using the {@code pattern} grammar rule.
     *
     * @throws GraqlSyntaxException if the lexer or parser reported a syntax error
     */
    public Pattern parsePattern(String str) {
        return parseQueryFragment(GraqlParser::pattern, QueryVisitor::visitPattern, str, getLexer(str));
    }

    /**
     * Parses {@code str} with a buffered {@link CommonTokenStream} over {@code lexer}, reporting
     * errors against {@code str}.
     */
    private <T, S extends ParseTree> T parseQueryFragment(
            Function<GraqlParser, S> parseRule, BiFunction<QueryVisitor, S, T> visit, String str, GraqlLexer lexer) {
        GraqlErrorListener errorListener = attachErrorListener(lexer, str);
        return parseQueryFragment(parseRule, visit, new CommonTokenStream(lexer), errorListener);
    }

    /**
     * Runs {@code parseRule} on a new parser over {@code tokens}. Throws if {@code errorListener}
     * recorded any error; otherwise returns the result of visiting the tree.
     *
     * @throws GraqlSyntaxException built from {@code errorListener.toString()} when errors were recorded
     */
    private <T, S extends ParseTree> T parseQueryFragment(
            Function<GraqlParser, S> parseRule, BiFunction<QueryVisitor, S, T> visit,
            TokenStream tokens, GraqlErrorListener errorListener) {
        GraqlParser parser = new GraqlParser(tokens);
        parser.removeErrorListeners();
        parser.addErrorListener(errorListener);

        S tree = parseRule.apply(parser);

        // Errors from both the lexer and the parser land in the same listener; checked before visiting.
        if (errorListener.hasErrors()) {
            throw GraqlSyntaxException.parsingError(errorListener.toString());
        }
        return visit.apply(getQueryVisitor(), tree);
    }

    /**
     * Replaces {@code lexer}'s error listeners with a new {@link GraqlErrorListener} for
     * {@code input}, and returns that listener.
     */
    private static GraqlErrorListener attachErrorListener(GraqlLexer lexer, String input) {
        GraqlErrorListener errorListener = new GraqlErrorListener(input);
        lexer.removeErrorListeners();
        lexer.addErrorListener(errorListener);
        return errorListener;
    }

    private GraqlLexer getLexer(String str) {
        return new GraqlLexer(new ANTLRInputStream(str));
    }

    /**
     * Creates a visitor over an immutable copy of the current aggregates. Aggregates registered
     * later therefore do not affect an in-progress visit.
     */
    private QueryVisitor getQueryVisitor() {
        return new QueryVisitor(ImmutableMap.copyOf(aggregateMethods), queryBuilder);
    }

    /** Registers the built-in aggregates, each with its fixed argument-count range. */
    private void registerDefaultAggregates() {
        registerAggregate("count", 0, args -> Graql.count());
        registerAggregate("ask", 0, args -> Graql.ask());
        registerAggregate("sum", 1, args -> Aggregates.sum((Var) args.get(0)));
        registerAggregate("max", 1, args -> Aggregates.max((Var) args.get(0)));
        registerAggregate("min", 1, args -> Aggregates.min((Var) args.get(0)));
        registerAggregate("mean", 1, args -> Aggregates.mean((Var) args.get(0)));
        registerAggregate("median", 1, args -> Aggregates.median((Var) args.get(0)));
        registerAggregate("std", 1, args -> Aggregates.std((Var) args.get(0)));
        // "group" takes a variable and, optionally, an inner aggregate to apply per group.
        registerAggregate("group", 1, 2, args -> {
            Var var = (Var) args.get(0);
            if (args.size() < 2) {
                return Aggregates.group(var);
            }
            return Aggregates.group(var, (Aggregate) args.get(1));
        });
    }
}
