import { Entity, ENTITY_TYPES, EntityHelpers, EntityRole } from '@luis/entities';
import { BehaviorSubject, Observable } from 'rxjs/Rx';
import { emptyResultFactory, IModelConfusionAggregate } from '../interfaces/IModelConfusionAggregate';
import { DatasetResultItem } from './dataset-result-item.model';
import { CONFUSION_TYPE } from './entity-confusion.model';
import { LuisModel } from '@luis/core';

/**
 * @description
 * Represents a helper class that does all the heavy lifting for
 * the dataset results and tables components. Provides metrics
 * such as overall correctness of the utterances in the test,
 * detailed correctness of each model in the test, etc.
 */
export class DatasetScorer {
	/** Whether prebuilt entities take part in the scoring; starts as `false`. */
	private readonly _includePrebuilts: BehaviorSubject<boolean> = new BehaviorSubject<boolean>(false);

	constructor(private readonly _resultItems: DatasetResultItem[], private readonly _entities: Entity[]) {}

	/**
	 * @description
	 * Gets the raw result items that were returned from running
	 * the dataset on the server.
	 *
	 * @returns The results of the dataset run.
	 */
	public get resultItems(): DatasetResultItem[] {
		return this._resultItems;
	}

	/**
	 * @description
	 * Gets the number of utterances that passed every check: intent,
	 * entities and roles. Prebuilt entities are only considered when
	 * `includePrebuilts` is `true`; roles are always considered.
	 *
	 * @returns An observable of the count that re-emits whenever
	 * `includePrebuilts` changes.
	 */
	public get overallCorrectCount(): Observable<number> {
		return this._includePrebuilts
			.asObservable()
			.map(showPrebuilts => {
				const entityRoles = EntityHelpers.extractRoles(this._entities);

				return showPrebuilts
					? [...entityRoles, ...this._entities]
					: [...entityRoles, ...this._entities.filter(e => e.type !== ENTITY_TYPES.PREBUILT)];
			})
			.map(entitiesToShow => entitiesToShow.map(e => e.id))
			.map(ids => this._getCorrectResultItems(this._resultItems, ids));
	}

	/**
	 * @description
	 * Gets the confusion scores for each model (intent, entity and role)
	 * aggregated over all the utterances of this dataset.
	 *
	 * @returns An observable of a map from model id to a confusion
	 * aggregation object holding the count of each confusion type that
	 * model registered, plus its per-utterance scores. Re-emits whenever
	 * `includePrebuilts` changes.
	 */
	public get modelConfusionAggregates(): Observable<Map<string, IModelConfusionAggregate>> {
		return this._includePrebuilts.asObservable().map(
			showPrebuilts =>
				new Map<string, IModelConfusionAggregate>([
					...this._getIntentConfusions(this._resultItems),
					...this._getEntityOrRoleConfusions(this._resultItems, this._entities, showPrebuilts),
					// NOTE(review): roles are always scored with showPrebuilts=false, independent of the
					// toggle; confirm roles never carry the PREBUILT entity type.
					...this._getEntityOrRoleConfusions(this._resultItems, EntityHelpers.extractRoles(this._entities))
				])
		);
	}

	/**
	 * @description
	 * Gets the total confusion count (tp + tn + fp + fn) for each model
	 * that was tested across all of the dataset utterances.
	 *
	 * @returns An observable of a map from model id to the count of all
	 * confusions registered for that model.
	 */
	public get totalModelConfusionAggregates(): Observable<Map<string, number>> {
		const getTotalCount = (id: string, map: Map<string, IModelConfusionAggregate>) => {
			const aggregate = map.get(id);

			return aggregate.tp + aggregate.tn + aggregate.fp + aggregate.fn;
		};

		return this.modelConfusionAggregates.map(aggregates => this._confusionAggregatorRunner(aggregates, getTotalCount));
	}

	/**
	 * @description
	 * Gets the correct confusion count (tp + tn) for each model that
	 * was tested across all of the dataset utterances.
	 *
	 * @returns An observable of a map from model id to the count of
	 * correct (true positive and true negative) confusions for that model.
	 */
	public get correctModelConfusionAggregates(): Observable<Map<string, number>> {
		const getCorrectCount = (id: string, map: Map<string, IModelConfusionAggregate>) => {
			const aggregate = map.get(id);

			return aggregate.tp + aggregate.tn;
		};

		return this.modelConfusionAggregates.map(aggregates => this._confusionAggregatorRunner(aggregates, getCorrectCount));
	}

	/**
	 * @description
	 * Updates the stream that controls whether prebuilt entities are
	 * included in the scoring. All observables exposed by this class
	 * re-emit with the new setting.
	 */
	public set includePrebuilts(value: boolean) {
		this._includePrebuilts.next(value);
	}

	/**
	 * @description
	 * Counts the utterances that are fully correct. An utterance is correct
	 * when its top predicted intent matches the labeled intent (by name) and
	 * none of its included entity/role confusions is a false positive or a
	 * false negative. An utterance with no predicted intents is counted as
	 * incorrect.
	 *
	 * @param items The results from the dataset run.
	 * @param entityIdsToInclude The ids to check from the confusion scores.
	 * This parameter is useful when we want to exclude some entities from
	 * the score calculations, eg: prebuilt entities.
	 * @returns The number of utterances that succeeded fully.
	 */
	private _getCorrectResultItems(items: DatasetResultItem[], entityIdsToInclude: string[]): number {
		// Set lookup instead of a linear indexOf for every confusion key of every item.
		const includedIds = new Set<string>(entityIdsToInclude);
		const isIncorrect = (type: number) =>
			type === CONFUSION_TYPE.FALSE_POSITIVE || type === CONFUSION_TYPE.FALSE_NEGATIVE;

		return items.reduce((acc, item) => {
			const topIntent = item.utterance.predictedIntents[0];

			// No prediction at all cannot be correct; guard instead of throwing on `undefined.name`.
			if (!topIntent || item.utterance.labeledIntent.name !== topIntent.name) {
				return acc;
			}

			const hasIncorrectEntity = Array.from(item.entityConfusions.keys())
				.filter(id => includedIds.has(id))
				.some(id => item.entityConfusions.get(id).some(c => isIncorrect(c.type)));

			return hasIncorrectEntity ? acc : acc + 1;
		}, 0);
	}

	/**
	 * @description
	 * Aggregates the confusion scores for each intent across all items. That is,
	 * this function creates a map of intent id and a confusion aggregation
	 * object that contains the counts of all true positives, true negatives,
	 * etc. from all the utterances that were included in the dataset.
	 *
	 * Only the top-ranked prediction (index 0) counts as "predicted":
	 * - labeled and top-ranked      -> true positive
	 * - labeled but not top-ranked  -> false negative
	 * - not labeled but top-ranked  -> false positive
	 * - neither                     -> true negative
	 *
	 * Side effect: copies the labeled intent's predicted score onto
	 * `item.utterance.labeledIntent.score`.
	 *
	 * @param items The items that resulted from the dataset being run.
	 * @returns A map of intent id and the confusion aggregation object.
	 */
	private _getIntentConfusions(items: DatasetResultItem[]): Map<string, IModelConfusionAggregate> {
		return items.reduce((acc, item) => {
			const labeledIntent = item.utterance.labeledIntent;

			item.utterance.predictedIntents.forEach((pI, index) => {
				const result: IModelConfusionAggregate = acc.get(pI.id) || emptyResultFactory();
				const isLabeled = pI.id === labeledIntent.id;
				const isTopPrediction = index === 0;
				let confusion: number;

				if (isLabeled && isTopPrediction) {
					result.tp = result.tp + 1;
					confusion = CONFUSION_TYPE.TRUE_POSITIVE;
				} else if (isLabeled) {
					result.fn = result.fn + 1;
					confusion = CONFUSION_TYPE.FALSE_NEGATIVE;
				} else if (isTopPrediction) {
					result.fp = result.fp + 1;
					confusion = CONFUSION_TYPE.FALSE_POSITIVE;
				} else {
					result.tn = result.tn + 1;
					confusion = CONFUSION_TYPE.TRUE_NEGATIVE;
				}

				// The labeled intent's score is recorded wherever it was ranked.
				if (isLabeled) {
					labeledIntent.score = pI.score;
				}

				result.utteranceScores.set(item.utterance.id, [{ score: pI.score, confusion: confusion }]);
				acc.set(pI.id, result);
			});

			return acc;
		}, new Map<string, IModelConfusionAggregate>());
	}

	/**
	 * @description
	 * Aggregates the confusion scores for each entity or role across all items. That is,
	 * this function creates a map of model id and a confusion aggregation
	 * object that contains the counts of all true positives, true negatives,
	 * etc. from all the utterances that were included in the dataset.
	 * Confusion keys that do not match any of `models` are skipped.
	 *
	 * @param items The items that resulted from the dataset being run.
	 * @param models The entities or roles to aggregate confusions for.
	 * @param showPrebuilts When `false`, models of type PREBUILT are skipped.
	 * @returns A map of model (entity or role) id and the confusion aggregation object.
	 */
	private _getEntityOrRoleConfusions(
		items: DatasetResultItem[],
		models: LuisModel[],
		showPrebuilts: boolean = false
	): Map<string, IModelConfusionAggregate> {
		// Index models once; keep the first model per id to match Array.prototype.find semantics.
		const modelsById = new Map<string, LuisModel>();
		models.forEach(m => {
			if (!modelsById.has(m.id)) {
				modelsById.set(m.id, m);
			}
		});

		return items.reduce((acc, item) => {
			Array.from(item.entityConfusions.keys()).forEach(eId => {
				const modelMatch = modelsById.get(eId);

				if (!modelMatch || (!showPrebuilts && modelMatch.type === ENTITY_TYPES.PREBUILT)) {
					return;
				}

				const result: IModelConfusionAggregate = acc.get(eId) || emptyResultFactory();
				const confusions = item.entityConfusions.get(eId);

				confusions.forEach(cI => {
					switch (cI.type) {
						case CONFUSION_TYPE.TRUE_POSITIVE:
							result.tp = result.tp + 1;
							break;
						case CONFUSION_TYPE.TRUE_NEGATIVE:
							result.tn = result.tn + 1;
							break;
						case CONFUSION_TYPE.FALSE_POSITIVE:
							result.fp = result.fp + 1;
							break;
						case CONFUSION_TYPE.FALSE_NEGATIVE:
							result.fn = result.fn + 1;
							break;
						default:
					}
				});

				result.utteranceScores.set(
					item.utterance.id,
					confusions.map(cI => ({ score: cI.score, confusion: cI.type }))
				);

				acc.set(eId, result);
			});

			return acc;
		}, new Map<string, IModelConfusionAggregate>());
	}

	/**
	 * @description
	 * Reduces the confusion aggregates to a map of model id and a single
	 * count, computed by the given aggregation function.
	 *
	 * @param confusionScores The map of confusion aggregates.
	 * @param aggregationFunction Computes the count for one model id.
	 * @returns A map of the model id and the aggregated count.
	 */
	private _confusionAggregatorRunner(
		confusionScores: Map<string, IModelConfusionAggregate>,
		aggregationFunction: (id: string, map: Map<string, IModelConfusionAggregate>) => number
	): Map<string, number> {
		const ids = Array.from(confusionScores.keys());

		return ids.reduce((acc, id) => acc.set(id, aggregationFunction(id, confusionScores)), new Map<string, number>());
	}
}
