% Nivre's parser using D-rules or classifiers
% Author: Pierre Nugues
% Implementation improved from the book version
% Language processing using Perl and Prolog, Springer, 2006
% To load a corpus and extract the action list
% 1. Load the corpus file with the corpus/1 predicate
% 2. Call corpus(Corpus), parse_corpus(Corpus, File).
% The Corpus fact is created from a CoNLL formatted file using 
% convert_conll_clause.pl

% Declares the character encoding in which the rest of this source file
% is read (ISO Latin-1). This is the SWI-Prolog style encoding/1 directive.
% NOTE(review): the sample words in the comments further down look like
% UTF-8 text decoded as Latin-1 -- confirm the real encoding of the file.
:- encoding(iso_latin_1).

% Top-level predicate to parse a corpus and write the output in File.
% parse_corpus(+Corpus, +File)
%   Corpus: list of sentences; each sentence is a list of w(Features)
%           terms (make_word/2 shows the features that are required).
%   File:   name of the output file. It is opened in write mode with the
%           UTF-8 encoding.
% Can use two oracles with: D-rules or classifiers
% (oracle_drules or oracle_ml); the choice is made in nivre_parser/5.
%
% The current output is redirected to File while the corpus is parsed.
% The redirection and the stream are both released by cleanup handlers,
% so that a failure or an exception raised while parsing neither leaks
% the stream nor leaves the output redirected to it.
parse_corpus(Corpus, File) :-
	current_output(OldOutput),
	% once/1 keeps the parsing goal deterministic, which lets the
	% cleanup handlers run as soon as the goal exits.
	setup_call_cleanup(
	    open(File, write, Stream, [encoding(utf8)]),
	    setup_call_cleanup(
		set_output(Stream),
		once(parse_corpus_x(Corpus, 0)),
		set_output(OldOutput)),
	    close(Stream)).

% parse_corpus_x(+Sentences, +Count)
% Parses the sentences one after the other and prints each resulting
% graph on the current output. Count is the number of sentences already
% processed; it is only maintained for tracing purposes, e.g.:
%	write(Count), write(': ')
% The first parse found for a sentence is kept: once/1 discards the
% alternatives of make_word/2, nivre_parser/3 and print_graph/2.
parse_corpus_x([], _).
parse_corpus_x([Sentence | Sentences], Count) :-
	NextCount is Count + 1,
	once(( make_word(Sentence, Words),
	       Root = w([id=0, form=root, postag='ROOT']),
	       nivre_parser([Root | Words], _, Graph),
	       print_graph(Words, Graph)
	     )),
	parse_corpus_x(Sentences, NextCount).

% Formats the graphs using the CoNLL 2006 format.
% print_graph(+Words, ?Graph)
%   Words: list of w([id=ID, form=FORM, postag=POSTAG]) terms.
%   Graph: list of w([id=, form=, postag=, head=, deprel=]) terms.
% Writes one tab-separated line per word on the current output:
%   ID FORM _ POSTAG POSTAG _ HEAD DEPREL
% followed by an empty line after the last word.
% Words that have no head are assigned the ROOT head (0, 'ROOT').
%
% Two cases are handled explicitly:
% - Graph is not a proper list. nivre_parser/3 leaves the graph unbound
%   when it answers [fail]; searching an unbound list with member/2
%   would "find" every word and print unbound variables. Such a graph is
%   treated as empty, so that every word gets the ROOT head.
% - The deprel of an arc is unbound (left_arc/5 and right_arc/6 do not
%   label the arcs). The CoNLL placeholder '_' is printed instead of the
%   internal name of the variable.
print_graph(Words, Graph) :-
	(   is_list(Graph)
	->  Arcs = Graph
	;   Arcs = []
	),
	print_graph_words(Words, Arcs).

% print_graph_words(+Words, +Arcs)
% Prints the words in order; Arcs is a proper list of arcs.
print_graph_words([], _) :- nl.
print_graph_words([w([id=ID, form=FORM, postag=POSTAG]) | W], G) :-
	(   memberchk(w([id=ID, form=FORM, postag=POSTAG, head=HEAD, deprel=DEPREL0]), G)
	->  (   var(DEPREL0)
	    ->  DEPREL = '_'
	    ;   DEPREL = DEPREL0
	    )
	;   HEAD = 0,
	    DEPREL = 'ROOT'
	),
	writef('%w\t%w\t%w\t%w\t%w\t%w\t%w\t%w\n', 
	       [ID, FORM, '_', POSTAG, POSTAG, '_', HEAD, DEPREL]),
	print_graph_words(W, G).

% Formats the words as a subset of the CoNLL 2006 format
% make_word(+W, -FormattedWord)
%   W:             list of w(Features) terms, Features being a list of
%                  Name=Value pairs in any order.
%   FormattedWord: list of w([id=ID, form=FORM, postag=POSTAG]) terms,
%                  one per input word, in the same order.
% Fails when a word lacks one of the id, form or postag features.
make_word(Words, FormattedWords) :-
	maplist(make_word_item, Words, FormattedWords).

% make_word_item(+Word, -FormattedWord)
% Keeps the id, form and postag features of one word, in this order.
make_word_item(w(Features), w([id=Id, form=Form, postag=Pos])) :-
	member(id=Id, Features),
	member(form=Form, Features),
	member(postag=Pos, Features).

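% Nivre's dependency parser (shift-reduce)
% We store the words as 
% [w([id='0', form='ROOT']), w([id='1', form='Äktenskapet']), w(), ...]
% (the predicates below additionally expect a postag feature in each word).
% We store each dependency arc as the dependent's word term extended with
% its head and deprel, rather than as d(Head, Dependent, Function):
% [w([id='1', form='Äktenskapet', head=4, deprel='SS']), ...]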

% nivre_parser(+Sentence, ?Operations, ?Graph)
% Parses Sentence, a list of words, starting with an empty stack and an
% empty graph. Operations is the list of transitions that were applied
% and Graph the list of arcs that were built.
% When no parse is found, and as the last alternative on backtracking,
% Operations is [fail] and Graph is left unbound.
nivre_parser(Sentence, Ops, Graph) :-
	(   nivre_parser(Sentence, [], Ops, [], Graph)
	%   Further improvements that could be checked here:
	%	all_connected(Sentence, Graph),
	%	unique_root(Graph)
	;   Ops = [fail]
	).

% Auxiliary predicate: the parsing loop.
% nivre_parser(+Words, +Stack, -Ops, +CurGraph, -RefGraph)
% Stops when the list of input words is empty; the current graph is
% then the result. Otherwise, asks the oracle for the next transition,
% applies it to the parser state and carries on with the new state.
% The oracle in use is oracle_ml/4. To parse with D-rules, call instead:
%	oracle_drules(Queue, Stack, Op)
nivre_parser([], _, [], Graph, Graph).
nivre_parser(Queue, Stack, Transitions, Graph0, Graph) :-
	Transitions = [Op | Rest],
	oracle_ml(Queue, Stack, Graph0, Op),
	execute_action(Op, Queue, Queue1, Stack, Stack1, Graph0, Graph1),
	nivre_parser(Queue1, Stack1, Rest, Graph1, Graph).

% execute_action(+Op, +Words, -NewWords, +Stack, -NewStack, +Graph, -NewGraph)
% Executes the operation and produces a new parser state
%   la: left-arc,  delegates to left_arc/5; the input words are unchanged.
%   ra: right-arc, delegates to right_arc/6.
%   re: reduce,    delegates to reduce/3; words and graph are unchanged.
%   sh: shift,     delegates to shift/4; the graph is unchanged.
% Fails when Op is one of these four operations but cannot be applied
% to the current state.
%
% Any other operation is illegal. It is reported on user_error, and not
% on the current output, which parse_corpus/2 redirects to the result
% file. The new state then has an empty word list, which makes
% nivre_parser/5 stop, and keeps the stack and the graph built so far
% instead of leaving them unbound.
execute_action(la, Words, Words, Stack, NStack, Graph, NGraph) :-
	left_arc(Words, Stack, NStack, Graph, NGraph).
execute_action(ra, Words, NWords, Stack, NStack, Graph, NGraph) :-
	right_arc(Words, NWords, Stack, NStack, Graph, NGraph).
execute_action(re, Words, Words, Stack, NStack, Graph, Graph) :-
	reduce(Stack, NStack, Graph).
execute_action(sh, Words, NWords, Stack, NStack, Graph, Graph) :-
	shift(Words, NWords, Stack, NStack).
execute_action(Op, _, [], Stack, Stack, Graph, Graph) :-
	\+ member(Op, [la, ra, re, sh]),
	write(user_error, 'Illegal action. Returning'), nl(user_error).


% left_arc(+Words, +Stack, -NewStack, +Graph, -NewGraph)
% Makes the top of the stack a dependent of the first input word.
% The top is popped from the stack and an arc w([id, form, postag, head,
% deprel]) describing it is added in front of the graph; its head is the
% id of the first input word and its deprel is left unbound.
% Fails if the top of the stack already has an arc in the graph, or if
% the stack or the word list is empty.
left_arc([w([id=HeadId | _]) | _], [Top | Stack], Stack, Graph, [Arc | Graph]) :-
	Top = w([id=DepId, form=DepForm, postag=DepPos | _]),
	\+ member(w([id=DepId, form=DepForm | _]), Graph),
	Arc = w([id=DepId, form=DepForm, postag=DepPos, head=HeadId, deprel=_]).

% right_arc(+Words, -NewWords, +Stack, -NewStack, +Graph, -NewGraph)
% Makes the first input word a dependent of the top of the stack.
% The word is removed from the input and pushed onto the stack, and an
% arc w([id, form, postag, head, deprel]) describing it is added in
% front of the graph; its head is the id of the former top of the stack
% and its deprel is left unbound.
% Fails if the word already has an arc in the graph, or if the stack or
% the word list is empty.
right_arc([Next | Words], Words, [Top | Stack], [Next, Top | Stack], Graph, [Arc | Graph]) :-
	Top = w([id=HeadId | _]),
	Next = w([id=DepId, form=DepForm, postag=DepPos | _]),
	\+ member(w([id=DepId, form=DepForm | _]), Graph),
	Arc = w([id=DepId, form=DepForm, postag=DepPos, head=HeadId, deprel=_]).

% reduce(+Stack, -NewStack, +Graph)
% Pops the top of the stack. This is only possible when the top already
% has an arc in the graph, i.e. when an entry with the same id and form
% is found there. Fails otherwise, or if the stack is empty.
reduce([w([id=Id, form=Form | _]) | Stack], Stack, Graph) :-
	member(w([id=Id, form=Form | _]), Graph).

% shift(+Words, -NewWords, +Stack, -NewStack)
% Moves the first input word onto the stack.
% Fails if there is no input word left.
shift(Words, Rest, Stack, [Next | Stack]) :-
	Words = [Next | Rest].


% The oracle to handle the Weka classifier
% oracle_ml(+Words, +Stack, +Graph, -Operation)
% Extracts the features of the parser state, asks the classifier for a
% transition and checks that it is legal with legal_transition/3.
% choose/2 is not defined in this file; presumably it is the classifier
% exported from Weka -- verify against the file that is loaded with it.
% The feature list passed to choose/2 must be the one the classifier was
% built with. The other feature sets that have been used are:
%	choose([S0, S1, W0, W1, W2, LA, RA, RE, LMS], Predicted),
%	choose([S0, S1, W0, W1, W2, LA, RA, RE], Predicted),
%	choose([S0, S1, W0, W1, W2], Predicted),
%	choose([S0, S1, W0, W1, LA, RA, RE, LMS], Predicted),
%	choose([S0, S1, W0, W1, LA, RA, RE], Predicted),
%	choose([S0, S1, W0, W1, LMS], Predicted),
%	choose([S0, S1, W0, W1], Predicted),
%	choose([S0, W0, LA, RA, RE, LMS], Predicted),
%	choose([S0, W0, LMS], Predicted),
%	choose([S0, W0, LA, RA, RE], Predicted),
%	choose([S0, W0], Predicted),
oracle_ml(Words, Stack, Graph, Op) :-
	extract_features(Words, Stack, Graph, Features),
	Features = [S0, S1, W0, W1, W2, LA, RA, RE, LMS],
	choose([S0, S1, W0, W1, W2, LMS], Predicted),
	legal_transition([LA, RA, RE], Predicted, Op).

% Checks if the transition is legal
% legal_transition([+BoolLA, +BoolRA, +BoolRE], +Op, -OpOut)
% OpOut is Op when Op is la, ra or re and the corresponding Boolean is
% true. In all the other cases OpOut is sh.
% sh is also offered as an alternative, on backtracking, after a legal
% la, ra or re.
legal_transition(Flags, Predicted, Op) :-
	(   Flags = [true, _, _], Predicted = la, Op = la
	;   Flags = [_, true, _], Predicted = ra, Op = ra
	;   Flags = [_, _, true], Predicted = re, Op = re
	;   Op = sh
	).


% The oracle to handle D-rules
% oracle_drules(+Words, +Stack, -Operation)
% Proposes the transitions in this order, the next ones being
% alternatives on backtracking:
%   la if a D-rule makes the top of the stack a left dependent of the
%      first input word;
%   ra if a D-rule makes the first input word a right dependent of the
%      top of the stack;
%   sh if a D-rule makes the second input word a right dependent of the
%      top of the stack;
%   re, and finally sh, whatever the state.
% Only the part-of-speech tags of the words are compared.
oracle_drules([w([id=_, form=_, postag=PosNext | _]) | _],
	      [w([id=_, form=_, postag=PosTop | _]) | _], la) :-
	drule(PosNext, PosTop, _, left).
oracle_drules([w([id=_, form=_, postag=PosNext | _]) | _],
	      [w([id=_, form=_, postag=PosTop | _]) | _], ra) :-
	drule(PosTop, PosNext, _, right).
oracle_drules([_, w([id=_, form=_, postag=PosSecond | _]) | _],
	      [w([id=_, form=_, postag=PosTop | _]) | _], sh) :-
	drule(PosTop, PosSecond, _, right).
oracle_drules(_, _, re).
oracle_drules(_, _, sh).


% The dependency rules (D-rules)
% drule(+HeadPOS, +DependentPOS, +Function, +Direction)
%   HeadPOS:      part-of-speech tag of the head.
%   DependentPOS: part-of-speech tag of the dependent.
%   Function:     name of the grammatical function. It is `unknown' in
%                 all the Swedish rules; oracle_drules/3 ignores it.
%   Direction:    left or right. As used by oracle_drules/3, left means
%                 that the dependent is the top of the stack and the head
%                 the first input word; right means that the head is the
%                 top of the stack and the dependent an input word.
% English (example rules, disabled)
/*
drule('ROOT', 'VBD', root, right).
drule('NN', 'DT', determinative, left).
drule('NN', 'JJ', attribute, left).
drule('VBD', 'NN', subject, left).
drule('VBD', 'PRP', subject, left).
drule('VBD', 'NN', object, right).
drule('VBD', 'PRP', object, right).
drule('VBD', 'IN', adv, _).
drule('NN', 'IN', pmod, right).
drule('IN', 'NN', pcomp, right).
*/
% Swedish
% 100 most frequent rules extracted from the CoNLL 2006 training set
% NOTE(review): the table below actually holds 102 facts -- confirm
% which count is intended.
% The same pair of tags may occur with both directions.
drule('AV', 'PO', unknown, right).
drule('NN', 'IK', unknown, right).
drule('VV', 'AJ', unknown, right).
drule('VV', 'NN', unknown, right).
drule('PR', 'VN', unknown, right).
drule('HV', 'PR', unknown, right).
drule('NN', 'PR', unknown, right).
drule('SV', 'PO', unknown, left).
drule('PR', '++', unknown, left).
drule('NN', 'NN', unknown, left).
drule('VN', 'NN', unknown, left).
drule('VV', 'PR', unknown, right).
drule('AV', 'NN', unknown, left).
drule('VV', 'AV', unknown, right).
drule('BV', 'AJ', unknown, right).
drule('QV', 'AB', unknown, right).
drule('AV', 'TP', unknown, right).
drule('AV', 'AJ', unknown, right).
drule('NN', '++', unknown, left).
drule('AJ', 'PO', unknown, left).
drule('VV', 'IK', unknown, right).
drule('AV', 'AB', unknown, left).
drule('SV', 'IP', unknown, right).
drule('AV', '++', unknown, left).
drule('NN', 'AJ', unknown, left).
drule('NN', 'NN', unknown, right).
drule('HV', 'NN', unknown, right).
drule('VN', '++', unknown, left).
drule('VN', 'PR', unknown, right).
drule('HV', 'NN', unknown, left).
drule('AV', 'IP', unknown, right).
drule('VV', 'PR', unknown, left).
drule('SV', 'UK', unknown, left).
drule('VN', 'AJ', unknown, left).
drule('PR', 'ID', unknown, right).
drule('PO', 'PO', unknown, right).
drule('AJ', 'AJ', unknown, right).
drule('VV', 'VN', unknown, left).
drule('AV', 'AB', unknown, right).
drule('VV', '++', unknown, left).
drule('AJ', '++', unknown, left).
drule('VV', 'IC', unknown, right).
drule('PO', 'ID', unknown, right).
drule('HV', 'IP', unknown, right).
drule('QV', 'IP', unknown, right).
drule('HV', 'VV', unknown, right).
drule('MN', 'ID', unknown, right).
drule('NN', 'PO', unknown, left).
drule('VV', 'PO', unknown, left).
drule('AV', 'NN', unknown, right).
drule('AV', 'PR', unknown, right).
drule('AV', 'IK', unknown, right).
drule('FV', 'NN', unknown, right).
drule('AV', 'VV', unknown, right).
drule('QV', 'PO', unknown, left).
drule('PN', 'ID', unknown, right).
drule('VN', 'PO', unknown, left).
drule('PR', 'AB', unknown, left).
drule('PR', 'VV', unknown, right).
drule('AJ', 'AB', unknown, left).
drule('NN', 'IK', unknown, left).
drule('VV', 'VN', unknown, right).
drule('PR', 'PO', unknown, right).
drule('QV', 'NN', unknown, left).
drule('PR', 'NN', unknown, right).
drule('QV', 'VV', unknown, right).
drule('VV', 'IP', unknown, right).
drule('PR', 'PN', unknown, right).
drule('QV', 'AB', unknown, left).
drule('SV', 'NN', unknown, left).
drule('VV', 'IK', unknown, left).
drule('NN', 'VV', unknown, right).
drule('SV', 'VV', unknown, right).
drule('NN', 'TP', unknown, left).
drule('MV', 'VV', unknown, right).
drule('VV', 'VV', unknown, right).
drule('PR', 'AN', unknown, right).
drule('VV', 'UK', unknown, left).
drule('NN', 'AB', unknown, left).
drule('VV', 'IM', unknown, left).
drule('PO', 'VV', unknown, right).
drule('MV', 'IP', unknown, right).
drule('HV', 'AB', unknown, right).
drule('VV', 'AB', unknown, left).
drule('HV', 'AB', unknown, left).
drule('VV', 'PN', unknown, left).
drule('VV', 'AB', unknown, right).
drule('VN', 'EN', unknown, left).
drule('NN', 'RO', unknown, left).
drule('NN', 'VN', unknown, right).
drule('VV', 'NN', unknown, left).
drule('AB', 'ID', unknown, right).
drule('HV', 'PO', unknown, left).
drule('PR', 'PR', unknown, right).
drule('NN', 'UK', unknown, left).
drule('VV', 'PO', unknown, right).
drule('SV', 'AB', unknown, left).
drule('AB', 'AB', unknown, left).
drule('AV', 'PO', unknown, left).
drule('UK', 'ID', unknown, right).
drule('NN', 'EN', unknown, left).
drule('AV', 'UK', unknown, left).

% Feature extraction for the classifier
% extract_features(+Words, +Stack, +Graph, -Features)
% Features is the list [S0, S1, W0, W1, W2, LA, RA, RE, LMS] where:
%   S0, S1:     part-of-speech tags of the first two stack items;
%   W0, W1, W2: part-of-speech tags of the first three input words;
%   LA, RA, RE: Boolean constraints telling whether left-arc, right-arc
%               and reduce are possible (see bool_feat/4);
%   LMS:        part-of-speech tag of the dependent of the top of the
%               stack that has the smallest id (see lmoststack_feat/3).
% A missing item is represented by nil.
extract_features(Words, Stack, Graph, Features) :-
	word_feat(Words, [W0, W1, W2]),
	stack_feat(Stack, [S0, S1]),
	lmoststack_feat(Stack, Graph, LMS),
	bool_feat(Words, Stack, Graph, [LA, RA, RE]),
	Features = [S0, S1, W0, W1, W2, LA, RA, RE, LMS].

% word_feat(+Words, -[W0, W1, W2])
% Part-of-speech tags of the first three input words. The tags of the
% words that are missing are replaced with nil. Each word must have the
% form w([id=_, form=_, postag=_]). Fails on an empty word list.
word_feat([w([id=_, form=_, postag=Pos0]) | Lookahead], [Pos0, Pos1, Pos2]) :-
	word_feat_lookahead(Lookahead, Pos1, Pos2).

% word_feat_lookahead(+Words, -Pos1, -Pos2)
% Tags of the two words that follow the first input word, or nil.
word_feat_lookahead([], nil, nil) :- !.
word_feat_lookahead([w([id=_, form=_, postag=Pos1])], Pos1, nil) :- !.
word_feat_lookahead([w([id=_, form=_, postag=Pos1]), w([id=_, form=_, postag=Pos2]) | _], Pos1, Pos2).

% stack_feat(+Stack, -[S0, S1])
% Part-of-speech tags of the top of the stack and of the item below it.
% The tags of the items that are missing are replaced with nil. Each
% item must have the form w([id=_, form=_, postag=_]).
stack_feat(Stack, Feats) :-
	(   Stack = [], Feats = [nil, nil]
	->  true
	;   Stack = [w([id=_, form=_, postag=Top])], Feats = [Top, nil]
	->  true
	;   Stack = [w([id=_, form=_, postag=Top]), w([id=_, form=_, postag=Second]) | _],
	    Feats = [Top, Second]
	).

% lmoststack_feat(+Stack, +Graph, -LM)
% LM is the part-of-speech tag of the dependent of the top of the stack
% that has the smallest id, i.e. the first id after sorting with sort/2.
% LM is nil when the stack is empty or when the top has no dependent in
% the graph.
% NOTE(review): sort/2 uses the standard order of terms. This is the
% numerical order only if the ids are numbers; confirm how the corpus
% represents them.
lmoststack_feat(Stack, Graph, LM) :-
	(   Stack = [w([id=TopId | _]) | _],
	    findall(DepId,
		    member(w([id=DepId, form=_, postag=_, head=TopId | _]), Graph),
		    DepIds),
	    sort(DepIds, [Leftmost | _]),
	    member(w([id=Leftmost, form=_, postag=LM, head=TopId | _]), Graph)
	->  true
	;   LM = nil
	).
	
% bool_feat(+Words, +Stack, +Graph, -[LA, RA, RE])
% Boolean constraints on the transitions, as the atoms true and false:
%   LA (left-arc) is true when the top of the stack has no arc in the
%      graph yet;
%   RE (reduce) is true when the top of the stack already has one;
%   RA (right-arc) is true when the first input word has no arc in the
%      graph yet.
% The three constraints are false when the stack is empty.
% Each constraint is decided with a single search of the graph
% (memberchk/2 and if-then-else), so the predicate is deterministic and
% leaves no choicepoint behind.
% NOTE(review): LA is also true when the top of the stack is the root
% word; confirm that this is intended.
bool_feat(_, [], _, [false, false, false]) :- !.
bool_feat([W | _], [T | _], Graph, [LA, RA, RE]) :-
	T = w([id=IDT | _]),
	W = w([id=IDW | _]),
	(   memberchk(w([id=IDT | _]), Graph)
	->  LA = false, RE = true
	;   LA = true, RE = false
	),
	(   memberchk(w([id=IDW | _]), Graph)
	->  RA = false
	;   RA = true
	).
   



