package de.xam.triplerules.impl;

import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

import org.xydra.index.ITripleIndex;
import org.xydra.index.iterator.IFilter;
import org.xydra.index.iterator.ITransformer;
import org.xydra.index.iterator.Iterators;
import org.xydra.index.query.Constraint;
import org.xydra.index.query.ITriple;
import org.xydra.log.api.Logger;
import org.xydra.log.api.LoggerFactory;

import de.xam.triplerules.IReadonlyTriplePatternSet;
import de.xam.triplerules.IRuleConditionBinding;
import de.xam.triplerules.ITriplePattern;
import de.xam.triplerules.ITripleRule;
import de.xam.triplerules.IVariable;

/**
 * Static helper methods for the {@link RuleEngine}.
 *
 * For rules with n patterns: try each pattern as the first pattern; for each such attempt, create a binding and try to
 * fill the remaining patterns in a cost-driven sequence using the tripleIndex (costs matter). The conditions are
 * commutative, so you just need to make sure each one could come from the newly inferred triples.
 *
 * @author xamde
 *
 * @param <K> subject type (key 1 of a triple)
 * @param <L> predicate type (key 2 of a triple)
 * @param <M> object type (entry of a triple)
 */
public class InferenceEngine<K extends Serializable, L extends Serializable, M extends Serializable> {

	// use logger on {@link de.xam.triplerules.impl.InfLog} to see a reasoning trace
	private static final Logger log = LoggerFactory.getLogger(InferenceEngine.class);

	/**
	 * Explode a rule condition for a single binding: pick the cheapest still-unbound pattern and extend the binding
	 * with every matching triple from the index, recursing until no unbound patterns remain.
	 *
	 * @param rule whose condition is being bound; its variable count must equal the binding capacity
	 * @param sourceTripleIndex index queried for matching triples
	 * @param binding the partial binding to extend
	 * @param costEstimator used to choose which unbound pattern to explode next
	 * @return an iterator over multiple, more defined bindings (i.e. more bound variables). If {@code binding} has no
	 *         unbound patterns left, this is a single-element iterator over {@code binding} itself. May be empty.
	 */
	private static <K, L, M> Iterator<ConditionBinding<K, L, M>> explodeBindingRecursively(
			final ITripleRule<K, L, M> rule, final ITripleIndex<K, L, M> sourceTripleIndex,
			final ConditionBinding<K, L, M> binding, final ICostEstimator<K, L, M> costEstimator) {
		assert binding.capacity() == rule.varCount();

		if (binding.unboundPatterns().isEmpty()) {
			// end recursion
			return Iterators.forOne(binding);
		}

		// explode all __remaining patterns__ in a clever sequence (commutative)
		if (log.isDebugEnabled()) {
			log.debug(org.xydra.sharedutils.DebugUtils.toIndent("==", binding.size() + 1) + "Explode binding: "
					+ binding.toString(rule));
		}
		// never null here: the empty case was handled above
		final ITriplePattern<K, L, M> cheapestPattern = getCheapestUnboundPattern(rule.condition(), binding,
				costEstimator);
		assert cheapestPattern != null;
		return explodeBindingRecursively_OnePattern(rule, cheapestPattern, sourceTripleIndex, binding, costEstimator);
	}

	/**
	 * Explode a single pattern for a single binding.
	 *
	 * The index is queried with the pattern. Variables already bound in {@code binding} are replaced by equality
	 * constraints (see {@link #toConstraint}). Each matching triple becomes a refined binding for the still-unbound
	 * variables, and those refined bindings are exploded recursively.
	 *
	 * NOTE(review): if the same still-unbound variable occurs in two positions of {@code pattern} (e.g.
	 * {@code ?x b ?x}), both positions are queried as independent stars. The second {@code setValue} call then
	 * overwrites the first. Presumably the index cannot enforce equality across positions; the public
	 * {@code query(rule, pattern, index)} re-filters with {@link #matches(ITriplePattern, ITriple)} for this reason. If
	 * so, such triples would yield inconsistent bindings here — TODO confirm.
	 *
	 * @param rule the rule whose condition is being bound
	 * @param pattern the pattern to match next
	 * @param sourceTripleIndex index queried for matching triples
	 * @param binding the partial binding to refine
	 * @param costEstimator passed on to the recursive explosion
	 * @return all bindings obtained by matching {@code pattern} and then the remaining unbound patterns; may be empty
	 */
	private static <K, L, M> Iterator<ConditionBinding<K, L, M>> explodeBindingRecursively_OnePattern(
			final ITripleRule<K, L, M> rule, final ITriplePattern<K, L, M> pattern,
			final ITripleIndex<K, L, M> sourceTripleIndex, final ConditionBinding<K, L, M> binding,
			final ICostEstimator<K, L, M> costEstimator) {

		if (log.isDebugEnabled()) {
			log.debug(org.xydra.sharedutils.DebugUtils.toIndent("**", binding.size() + 1) + "  " + "Explode pattern: "
					+ pattern.toString(binding) + " from " + rule.condition().toString(binding) + "; binding="
					+ binding.toString());
		}

		// implicit join happening via query
		final IVariable<K> cS = toConstraint(rule, pattern.s(), binding);
		final IVariable<L> cP = toConstraint(rule, pattern.p(), binding);
		final IVariable<M> cO = toConstraint(rule, pattern.o(), binding);
		final Iterator<? extends ITriple<K, L, M>> triplesFromIndex = query(sourceTripleIndex, cS, cP, cO, binding);

		final Iterator<ConditionBinding<K, L, M>> extendedBindings = Iterators.transform(triplesFromIndex,
				new ITransformer<ITriple<K, L, M>, ConditionBinding<K, L, M>>() {

					@Override
					public ConditionBinding<K, L, M> transform(final ITriple<K, L, M> triple) {
						if (log.isDebugEnabled()) {
							log.debug(org.xydra.sharedutils.DebugUtils.toIndent("..", binding.size() + 1) + "    "
									+ "Matching triple: " + triple);
						}
						// add new bindings; only positions that were still unbound (star) take values from the triple
						final ConditionBinding<K, L, M> wrapped = ConditionBinding.createRefined(binding, pattern);
						if (cS.isStar()) {
							wrapped.setValue(cS, triple.getKey1());
						}
						if (cP.isStar()) {
							wrapped.setValue(cP, triple.getKey2());
						}
						if (cO.isStar()) {
							wrapped.setValue(cO, triple.getEntry());
						}

						if (log.isTraceEnabled()) {
							log.trace("Created another binding " + wrapped);
						}

						return wrapped;
					}
				});
		// recursion!
		return explodeBindingsRecursively(rule, sourceTripleIndex, extendedBindings, costEstimator);
	}

	/**
	 * Used in iteration 0 + n.
	 *
	 * Recursive reduction step: each binding in {@code bindings} is exploded on its own, and the results are
	 * concatenated into one iterator.
	 *
	 * Takes into account different ways of choosing the next pattern.
	 *
	 * @param rule the rule whose condition is being bound
	 * @param sourceTripleIndex index queried for matching triples
	 * @param bindings partial bindings to explode
	 * @param costEstimator used to choose the next pattern for each binding
	 * @return the concatenation of the explosions of all given bindings; may be empty
	 */
	public static <K, L, M> Iterator<ConditionBinding<K, L, M>> explodeBindingsRecursively(
			final ITripleRule<K, L, M> rule, final ITripleIndex<K, L, M> sourceTripleIndex,
			final Iterator<ConditionBinding<K, L, M>> bindings, final ICostEstimator<K, L, M> costEstimator) {
		return Iterators.cascade(bindings,
				new ITransformer<ConditionBinding<K, L, M>, Iterator<ConditionBinding<K, L, M>>>() {

					@Override
					public Iterator<ConditionBinding<K, L, M>> transform(final ConditionBinding<K, L, M> binding) {

						if (log.isTraceEnabled()) {
							log.trace("Exploding " + binding);
						}

						return explodeBindingRecursively(rule, sourceTripleIndex, binding, costEstimator);
					}
				});
	}

	/**
	 * Choose the unbound pattern of {@code binding} with the lowest estimated cost.
	 *
	 * If no estimate is lower than {@link Double#MAX_VALUE} (e.g. all are infinite or NaN), the first unbound pattern
	 * is returned.
	 *
	 * @param condition currently not used by this method
	 * @param binding supplies the unbound patterns and is passed to the cost estimator
	 * @param costEstimator estimates the cost of each unbound pattern under {@code binding}
	 * @return the cheapest unbound pattern; @CanBeNull if all patterns are bound
	 */
	public static <K, L, M> ITriplePattern<K, L, M> getCheapestUnboundPattern(
			final IReadonlyTriplePatternSet<K, L, M> condition, final ConditionBinding<K, L, M> binding,
			final ICostEstimator<K, L, M> costEstimator) {
		if (binding.unboundPatterns().isEmpty()) {
			// all patterns are bound: nothing to choose; previously iterator().next() threw NoSuchElementException here
			return null;
		}
		ITriplePattern<K, L, M> bestPattern = binding.unboundPatterns().iterator().next();
		double lowestCosts = Double.MAX_VALUE;
		for (final ITriplePattern<K, L, M> pattern : binding.unboundPatterns()) {
			final double estimatedCosts = costEstimator.estimatedCosts(pattern, binding);
			if (log.isTraceEnabled()) {
				log.trace("Cost of " + pattern.toString(binding) + " = " + estimatedCosts);
			}
			if (estimatedCosts < lowestCosts) {
				bestPattern = pattern;
				lowestCosts = estimatedCosts;
			}
		}
		return bestPattern;
	}

	/**
	 * Choose the pattern of {@code condition} with the lowest positive estimated cost.
	 *
	 * NOTE(review): patterns whose estimate is {@code <= 0} are never selected, except as the default. Presumably a
	 * non-positive estimate means "no information" rather than "free" — confirm with the {@link ICostEstimator}
	 * implementations.
	 *
	 * @param condition must contain at least one pattern
	 * @param costEstimator estimates the cost of each pattern
	 * @return the cheapest pattern, or the first pattern if none has a positive estimate below
	 *         {@link Double#MAX_VALUE}; @NeverNull
	 */
	public static <K, L, M> ITriplePattern<K, L, M> getCheapestPattern(
			final IReadonlyTriplePatternSet<K, L, M> condition, final ICostEstimator<K, L, M> costEstimator) {
		assert condition.patterns().size() > 0;
		// default to the first pattern in iteration order, in case no pattern has a usable cost estimate
		ITriplePattern<K, L, M> bestPattern = condition.patterns().iterator().next();
		double lowestCosts = Double.MAX_VALUE;
		for (final ITriplePattern<K, L, M> pattern : condition.patterns()) {
			final double estimatedCosts = costEstimator.estimatedCosts(pattern);
			if (log.isTraceEnabled()) {
				log.trace("Cost of " + pattern.toString() + " = " + estimatedCosts);
			}
			if (estimatedCosts > 0 && estimatedCosts < lowestCosts) {
				bestPattern = pattern;
				lowestCosts = estimatedCosts;
			}
		}
		return bestPattern;
	}

	// /**
	// * Take the set of AND-ed {@link ITriplePattern}s and reduce them: Replace
	// * all variables in one pattern with an iterator of defined values. This
	// * might also define values for other patterns, if a variable occurs in
	// more
	// * than one pattern. The next step is to select the next triple pattern in
	// * which to replace variables with an iterator of defined values.
	// *
	// * @param rule
	// * @param tripleIndex
	// * @param costEstimator
	// * @return all possible bindings of the given condition; returned iterator
	// * has bindings for all variables in the rule condition
	// */
	//
	// /**
	// * Each step adds at least one variable to the returned binding. So the
	// * recursion ends when all variables are bound.
	// *

	/**
	 * Infer all triples that a single rule derives from a single (new) triple.
	 *
	 * For every pattern of the rule condition that {@code triple} matches, an initial binding is created from the
	 * triple. The remaining patterns are then bound against {@code tripleIndex}, and every complete binding is
	 * materialised via {@code RuleUtils.materialiseTriples}.
	 *
	 * @param ruleManager the rule is marked as matched here (once per matching pattern that yields a binding)
	 * @param rule the rule to apply
	 * @param triple the new triple to match against the rule's condition patterns
	 * @param tripleIndex used to bind the remaining patterns and passed on to materialisation
	 * @param newInf is filled with new triples
	 * @param iteration currently not used by this method
	 * @param costEstimator used to choose the order in which remaining patterns are bound
	 * @return number of inferred triples (sum of the counts returned by materialiseTriples)
	 */
	public static <K extends Serializable, L extends Serializable, M extends Serializable> int inferOneIteration_OneRule_OneTriple(
			final IRuleManager<K, L, M> ruleManager, final ITripleRule<K, L, M> rule, final ITriple<K, L, M> triple,
			final ITripleIndex<K, L, M> tripleIndex, final Collection<ITriple<K, L, M>> newInf, final int iteration,
			final ICostEstimator<K, L, M> costEstimator) {
		int inferredTriplesCount = 0;

		/* check if __new triple__ matches __any__ pattern of a rule condition */
		for (final ITriplePattern<K, L, M> pattern : rule.condition().patterns()) {

			if (log.isTraceEnabled()) {
				log.trace(">>> Checking pattern:  " + pattern);
			}

			if (!matches(pattern, triple)) {
				continue;
			}

			if (log.isTraceEnabled()) {
				log.trace(">>>> Pattern matches " + triple);
			}

			// compute binding for matching pattern
			final ConditionBinding<K, L, M> firstBinding = ConditionBinding.createInitialBinding(rule, pattern, triple);

			/* recursively reduce the number of unbound patterns; choose cheapest such pattern first */

			// might be empty
			final Iterator<ConditionBinding<K, L, M>> bindings = explodeBindingRecursively(rule, tripleIndex,
					firstBinding, costEstimator);

			if (bindings.hasNext()) {
				ruleManager.markRuleAsMatched(rule);

				if (log.isTraceEnabled()) {
					log.trace(">>>>> Bindings found (patterns used conjunctively): " + pattern.toString(firstBinding));
				}
			} else {
				if (log.isTraceEnabled()) {
					log.trace(">>>>> No bindings found (patterns used conjunctively)");
				}
			}

			while (bindings.hasNext()) {
				final IRuleConditionBinding<K, L, M> binding = bindings.next();

				if (log.isTraceEnabled()) {
					log.trace(">>>>>> Binding: " + binding);
				}

				// write into newInf
				inferredTriplesCount += RuleUtils.materialiseTriples(rule, binding, tripleIndex, newInf);
			}
		}
		return inferredTriplesCount;
	}

	/**
	 * Check whether a single triple matches a triple pattern, without consulting any index.
	 *
	 * Exact pattern positions must match the corresponding triple component. A variable that occurs in more than one
	 * position of the pattern (identified by equal {@code name()}, e.g. {@code ?x b ?x}) requires the corresponding
	 * triple components to be {@code equals}.
	 *
	 * @param pattern the pattern to test against
	 * @param triple the triple to test
	 * @return true if pattern matches triple
	 */
	public static <K, L, M> boolean matches(final ITriplePattern<K, L, M> pattern, final ITriple<K, L, M> triple) {

		final K s = triple.getKey1();
		final L p = triple.getKey2();
		final M o = triple.getEntry();

		if (pattern.s().isStar()) {
			// need to cross check with other vars
			if (pattern.p().isStar()) {
				if (pattern.s().name().equals(pattern.p().name())) {
					assert pattern.s().isStar();
					assert pattern.p().isStar();
					/**
					 * <pre>
					 * ?x ?x ?x
					 * ?x ?x ?c
					 * ?x ?x  c
					 * </pre>
					 */
					if (pattern.p().name().equals(pattern.o().name())) {
						return s.equals(p) && p.equals(o);
					}
					if (pattern.o().isStar()) {
						return s.equals(p);
					} else {
						return s.equals(p) && pattern.o().matches(o);
					}
				} else {
					assert pattern.s().isStar();
					assert pattern.p().isStar();
					/**
					 * <pre>
					 * ?x ?b ?x
					 * ?a ?x ?x
					 * ?a ?b ?c
					 * ?a ?b  c
					 * </pre>
					 */
					if (pattern.s().name().equals(pattern.o().name())) {
						// ?x ?b ?x
						return s.equals(o);
					} else if (pattern.p().name().equals(pattern.o().name())) {
						// ?a ?x ?x
						return p.equals(o);
					} else {
						if (pattern.o().isStar()) {
							// ?a ?b ?c
							return true;
						} else {
							// ?a ?b c
							return pattern.o().matches(o);
						}
					}
				}

			} else {
				assert pattern.s().isStar();
				assert pattern.p().isExact();
				/**
				 * <pre>
				 * ?x  b ?x
				 * ?a  b ?c
				 * ?a  b  c
				 * </pre>
				 */
				if (pattern.s().name().equals(pattern.o().name())) {
					return s.equals(o) && pattern.p().matches(p);
				} else {
					if (pattern.o().isStar()) {
						// ?a b ?c
						return pattern.p().matches(p);
					} else {
						// ?a b c
						return pattern.p().matches(p) && pattern.o().matches(o);
					}
				}
			}
		} else {
			assert pattern.s().isExact();
			if (pattern.p().isStar()) {
				/**
				 * <pre>
				 *  a ?x ?x
				 *  a ?b ?c
				 *  a ?b  c
				 * </pre>
				 */
				if (pattern.p().name().equals(pattern.o().name())) {
					// a ?x ?x
					return pattern.s().matches(s) && p.equals(o);
				} else {
					if (pattern.o().isStar()) {
						// a ?b ?c
						return pattern.s().matches(s);
					} else {
						// a ?b c
						return pattern.s().matches(s) && pattern.o().matches(o);
					}
				}
			} else {
				assert pattern.p().isExact();
				/**
				 * <pre>
				 *  a  b ?c
				 *  a  b  c
				 * </pre>
				 */
				if (pattern.o().isStar()) {
					return pattern.s().matches(s) && pattern.p().matches(p);
				} else {
					return pattern.s().matches(s) && pattern.p().matches(p) && pattern.o().matches(o);
				}
			}
		}
	}

	// /**
	// * @param rule required to lookup mapping varName-varId
	// * @param pattern
	// * @param binding
	// * @param sourceTripleIndex needs {@link ITripleIndex#contains(Constraint, Constraint, Constraint)}
	// * @return true if pattern matches in tripleIndex
	// */
	// @SuppressWarnings("unused")
	// private static <K, L, M> boolean matches(final ITripleRule<K, L, M> rule, final ITriplePattern<K, L, M> pattern,
	// final ConditionBinding<K, L, M> binding, final ITripleIndex<K, L, M> sourceTripleIndex) {
	// final Constraint<K> cS = toConstraint(rule, pattern.s(), binding);
	// final Constraint<L> cP = toConstraint(rule, pattern.p(), binding);
	// final Constraint<M> cO = toConstraint(rule, pattern.o(), binding);
	// return sourceTripleIndex.contains(cS, cP, cO);
	// }

	/**
	 * Query the index for triples matching the given (possibly refined) per-position constraints.
	 *
	 * When trace logging is enabled, the result is first copied into a list for logging and null-checking assertions.
	 * The index is then queried a second time, so the returned iterator is always unconsumed. (The summary lines are
	 * emitted at debug level, but only inside the trace guard.)
	 *
	 * @param sourceTripleIndex needs {@link ITripleIndex#getTriples(Constraint, Constraint, Constraint)}
	 * @param cS subject constraint
	 * @param cP predicate constraint
	 * @param cO object constraint
	 * @param binding only used to compute the log indentation
	 * @return the triples returned by the index for {@code (cS, cP, cO)}
	 */
	private static <K, L, M> Iterator<? extends ITriple<K, L, M>> query(final ITripleIndex<K, L, M> sourceTripleIndex,
			final IVariable<K> cS, final IVariable<L> cP, final IVariable<M> cO,
			final ConditionBinding<K, L, M> binding) {
		/* query index for triples matching the refined pattern */
		Iterator<? extends ITriple<K, L, M>> triplesFromIndex = sourceTripleIndex.getTriples(cS, cP, cO);

		if (log.isTraceEnabled()) {
			// consumes triplesFromIndex; a fresh iterator is fetched below
			final List<? extends ITriple<K, L, M>> list = Iterators.toList(triplesFromIndex);
			if (list.isEmpty()) {
				log.debug(org.xydra.sharedutils.DebugUtils.toIndent("__", binding.size() + 1) + "  No match found for ("
						+ cS + ", " + cP + ", " + cO + ") from " + sourceTripleIndex.getClass().getName());
			} else {
				log.debug(org.xydra.sharedutils.DebugUtils.toIndent("__", binding.size() + 1) + "  " + list.size()
						+ " matches for (" + cS + ", " + cP + ", " + cO + ") from "
						+ sourceTripleIndex.getClass().getName());
				for (final ITriple<K, L, M> t : list) {
					log.trace("Triple from index: " + t);
					assert t.getKey1() != null;
					assert t.getKey2() != null;
					assert t.getEntry() != null;
				}
			}

			triplesFromIndex = sourceTripleIndex.getTriples(cS, cP, cO);
		}

		return triplesFromIndex;
	}

	/**
	 * Query the index for all triples matching {@code pattern}.
	 *
	 * The raw index result is additionally filtered with {@link #matches(ITriplePattern, ITriple)}. That filter also
	 * enforces that variables repeated within the pattern have equal values.
	 *
	 * @param rule currently not used by this method
	 * @param pattern @NeverNull
	 * @param sourceTripleIndex needs {@link ITripleIndex#getTriples(Constraint, Constraint, Constraint)}
	 * @return an iterator over the triples of the index that match {@code pattern}
	 */
	public static <K, L, M> Iterator<? extends ITriple<K, L, M>> query(final ITripleRule<K, L, M> rule,
			final ITriplePattern<K, L, M> pattern, final ITripleIndex<K, L, M> sourceTripleIndex) {
		assert pattern != null;

		if (log.isDebugEnabled()) {
			log.debug("Query triple index: " + pattern);
		}

		final Iterator<ITriple<K, L, M>> base = sourceTripleIndex.getTriples(pattern.s(), pattern.p(), pattern.o());
		final IFilter<ITriple<K, L, M>> filter = new IFilter<ITriple<K, L, M>>() {

			@Override
			public boolean matches(final ITriple<K, L, M> entry) {
				return InferenceEngine.matches(pattern, entry);
			}
		};
		return Iterators.filter(base, filter);
	}

	/**
	 * Turn a pattern position into a query constraint that takes the current binding into account.
	 *
	 * @param rule used to compile a newly created {@link EqualsVariable}
	 * @param var a pattern position (variable or exact value)
	 * @param binding the current partial binding
	 * @return the variable, maybe restricted to already known bindings. This is {@code var} itself if it is exact or
	 *         still unbound; otherwise it is a new {@link EqualsVariable} with the same name and the already bound
	 *         value.
	 */
	@SuppressWarnings("unchecked")
	private static <K, L, M, E> IVariable<E> toConstraint(final ITripleRule<?, ?, ?> rule, final IVariable<E> var,
			final ConditionBinding<K, L, M> binding) {
		if (var.isStar()) {
			final Object o = binding.boundValue(var);
			if (o == null) {
				return var;
			} else {
				// already defined in binding, refine to this value
				final EqualsVariable<E> newVar = new EqualsVariable<E>(var.name(), (E) o);
				newVar.compile(rule);
				return newVar;
			}
		} else {
			return var;
		}
	}

}
