1 module workspaced.com.dcdext;
2 
3 import dparse.ast;
4 import dparse.lexer;
5 import dparse.parser;
6 import dparse.rollback_allocator;
7 
8 import core.thread;
9 
10 import std.algorithm;
11 import std.array;
12 import std.ascii;
13 import std.file;
14 import std.functional;
15 import std.json;
16 import std.range;
17 import std.string;
18 
19 import workspaced.api;
20 import workspaced.dparseext;
21 import workspaced.com.dcd;
22 
23 import workspaced.visitors.classifier;
24 import workspaced.visitors.methodfinder;
25 
@component("dcdext")
class DCDExtComponent : ComponentWrapper
{
	mixin DefaultComponentWrapper;

	/// Groups of protections that may share one sticky protection section
	/// (`public:`, `private:`, ...) when inserting code. Explicit `public` and
	/// unmarked (default) declarations are grouped together.
	static immutable CodeRegionProtection[] mixableProtection = [
		CodeRegionProtection.public_ | CodeRegionProtection.default_, CodeRegionProtection.package_,
		CodeRegionProtection.packageIdentifier, CodeRegionProtection.protected_,
		CodeRegionProtection.private_
	];

	/// Loads dcd extension methods. Call with `{"cmd": "load", "components": ["dcdext"]}`
	void load()
	{
		if (!refInstance)
			return;

		// keep string literal tokens exactly as they are written in the source
		config.stringBehavior = StringBehavior.source;
	}

	/// Finds the immediate surrounding code block at a position or returns CodeBlockInfo.init for none/module block.
	///
	/// Params:
	///   code = D source code to parse.
	///   position = offset into `code` (compared against token indices).
	/// See_Also: CodeBlockInfo
	CodeBlockInfo getCodeBlockRange(string code, int position)
	{
		// The AST is only needed during this call. A per-call allocator is released
		// when `rba` goes out of scope, instead of accumulating every parse in a
		// component-wide allocator that `implement` would also use from a worker thread.
		RollbackAllocator rba;
		auto tokens = getTokensForParser(cast(ubyte[]) code, config, &workspaced.stringCache);
		auto parsed = parseModule(tokens, "getCodeBlockRange_input.d", &rba);
		auto reader = new CodeBlockInfoFinder(position);
		reader.visit(parsed);
		return reader.block;
	}

	/// Inserts a generic method after the corresponding block inside the scope where position is.
	/// If it can't find a good spot it will insert the code properly indented at a fitting location.
	///
	/// Params:
	///   insert = code to insert; it is split into regions by `CodeDefinitionClassifier`.
	///   code = the full source code the returned replacements apply to.
	///   position = offset into `code` selecting the surrounding container block. Positions
	///     outside any class/interface/struct/union/template use the whole module.
	///   insertInLastBlock = prefer the last matching region/section over the first one.
	///   insertAtEnd = when a matching region already exists, insert after it instead of before it.
	/// Returns: zero-length replacements (pure insertions) with offsets into `code`.
	// make public once usable
	private CodeReplacement[] insertCodeInContainer(string insert, string code,
			int position, bool insertInLastBlock = true, bool insertAtEnd = true)
	{
		auto container = getCodeBlockRange(code, position);

		// getCodeBlockRange returns CodeBlockInfo.init for the module scope, whose
		// innerRange [0, 0] would be an empty slice; use the whole code instead.
		uint blockStart = container.innerRange[0];
		uint blockEnd = container.innerRange[1];
		if (container == CodeBlockInfo.init)
		{
			blockStart = 0;
			blockEnd = cast(uint) code.length;
		}
		string codeBlock = code[blockStart .. blockEnd];

		RollbackAllocator rba;
		scope tokensInsert = getTokensForParser(cast(ubyte[]) insert, config,
				&workspaced.stringCache);
		scope parsedInsert = parseModule(tokensInsert, "insertCode_insert.d", &rba);

		scope insertReader = new CodeDefinitionClassifier(insert);
		insertReader.visit(parsedInsert);
		scope insertRegions = insertReader.regions.sort!"a.type < b.type".uniq.array;

		scope tokens = getTokensForParser(cast(ubyte[]) codeBlock, config, &workspaced.stringCache);
		scope parsed = parseModule(tokens, "insertCode_code.d", &rba);

		scope reader = new CodeDefinitionClassifier(codeBlock);
		reader.visit(parsed);
		// Region offsets are relative to codeBlock; blockStart is added when emitting replacements.
		scope regions = reader.regions;

		CodeReplacement[] ret;

		foreach (CodeDefinitionClassifier.Region toInsert; insertRegions)
		{
			auto insertCode = insert[toInsert.region[0] .. toInsert.region[1]];
			scope existing = regions.enumerate.filter!(a => a.value.sameBlockAs(toInsert));
			if (existing.empty)
			{
				// all protections that may be mixed with the protection of the inserted code
				auto checkProtection = CodeRegionProtection.init.reduce!"a | b"(
						mixableProtection.filter!(a => (a & toInsert.protection) != 0));

				// Choose the sticky protection section (e.g. after `private:`) to insert into.
				// fittingProtection: chosen compatible section (last or first, see insertInLastBlock)
				// regionAfterFitting: incompatible sticky region ending that section, or -1 if the
				//   section reaches the end of the block
				int fittingProtection = -1;
				int firstStickyProtection = -1;
				int regionAfterFitting = -1;
				foreach (i, stickyProtection; regions)
				{
					if (!stickyProtection.affectsFollowing
							|| stickyProtection.protection == CodeRegionProtection.init)
						continue;

					if (firstStickyProtection == -1)
						firstStickyProtection = cast(int) i;

					if ((stickyProtection.protection & checkProtection) != 0)
					{
						if (insertInLastBlock || fittingProtection == -1)
						{
							fittingProtection = cast(int) i;
							regionAfterFitting = -1;
						}
					}
					else if (fittingProtection != -1 && regionAfterFitting == -1)
					{
						regionAfterFitting = cast(int) i;
						// the first compatible section is now closed, nothing more to find
						if (!insertInLastBlock)
							break;
					}
				}

				// Insert right before the sticky region closing the chosen section. Without any
				// compatible section, insert before the first sticky region (the top of the block
				// keeps the default protection). -1 means append after the last region.
				const int insertBefore = fittingProtection != -1 ? regionAfterFitting
					: firstStickyProtection;

				if (insertBefore != -1)
				{
					insertCode = indent(insertCode, regions[insertBefore].minIndentation) ~ "\n\n";
					const len = cast(uint) insertCode.length;

					toInsert.region[0] = regions[insertBefore].region[0];
					toInsert.region[1] = regions[insertBefore].region[0] + len;
					foreach (ref r; regions[insertBefore .. $])
					{
						r.region[0] += len;
						r.region[1] += len;
					}
					// keep regions ordered by position for the following insertions
					regions = regions[0 .. insertBefore] ~ toInsert ~ regions[insertBefore .. $];
				}
				else if (regions.length)
				{
					auto lastRegion = regions.back;
					// separate from the preceding declaration by a blank line
					insertCode = "\n\n" ~ indent(insertCode, lastRegion.minIndentation);
					const len = cast(uint) insertCode.length;
					toInsert.region[0] = lastRegion.region[1];
					toInsert.region[1] = lastRegion.region[1] + len;
					regions ~= toInsert;
				}
				else
				{
					// empty block: put the code on its own line at the start of the block
					insertCode = "\n" ~ insertCode;
					toInsert.region[0] = 0;
					toInsert.region[1] = cast(uint) insertCode.length;
					regions ~= toInsert;
				}

				ret ~= CodeReplacement([blockStart + toInsert.region[0],
						blockStart + toInsert.region[0]], insertCode);
			}
			else
			{
				auto target = insertInLastBlock ? existing.tail(1).front : existing.front;

				insertCode = indent(insertCode, regions[target.index].minIndentation);

				if (insertAtEnd)
				{
					// after the matching region: blank line first, then the new code
					insertCode = "\n\n" ~ insertCode;
					const codeLength = cast(int) insertCode.length;
					ret ~= CodeReplacement([blockStart + target.value.region[1],
							blockStart + target.value.region[1]], insertCode);
					regions[target.index].region[1] += codeLength;
					foreach (ref other; regions[target.index + 1 .. $])
					{
						other.region[0] += codeLength;
						other.region[1] += codeLength;
					}
				}
				else
				{
					// before the matching region: new code first, then the blank line
					insertCode ~= "\n\n";
					const codeLength = cast(int) insertCode.length;
					ret ~= CodeReplacement([blockStart + target.value.region[0],
							blockStart + target.value.region[0]], insertCode);
					regions[target.index].region[1] += codeLength;
					foreach (ref other; regions[target.index + 1 .. $])
					{
						other.region[0] += codeLength;
						other.region[1] += codeLength;
					}
				}
			}
		}

		return ret;
	}

	/// Implements the interfaces or abstract classes of a specified class/interface.
	///
	/// Collects the methods of the class/interface at `position` and of all its parents
	/// (resolved through DCD) that still need an implementation and generates stubs for them.
	/// A stub returns/assigns a matching field when one is found by name
	/// (`getFoo`/`setFoo`/`@property foo` against `_foo`, `m_foo`, `mFoo`, `foo`),
	/// forwards to `super` when the method already has a body, and otherwise contains
	/// `assert(false)`, or returns `null`/`.init` in `nothrow`/`@nogc` methods.
	///
	/// Params:
	///   code = source code containing the class/interface (handled as file "stdin").
	///   position = offset inside `code` identifying the class/interface.
	/// Returns: a Future resolving to the generated D code. The work runs in a thread from
	///   `threads`, and any Throwable is forwarded to the Future.
	Future!string implement(string code, int position)
	{
		auto ret = new Future!string;
		threads.create({
			try
			{
				struct InterfaceTree
				{
					InterfaceDetails details;
					InterfaceTree[] inherits;
				}

				auto baseInterface = getInterfaceDetails("stdin", code, position);

				string[] implementedMethods = baseInterface.methods
					.filter!"!a.needsImplementation"
					.map!"a.identifier"
					.array;

				// start with private, add all the public ones later in traverseTree
				FieldDetails[] availableVariables = baseInterface.fields.filter!"a.isPrivate".array;
				InterfaceTree tree = InterfaceTree(baseInterface);

				InterfaceTree* treeByName(InterfaceTree* tree, string name)
				{
					if (tree.details.name == name)
						return tree;
					foreach (ref parent; tree.inherits)
					{
						InterfaceTree* t = treeByName(&parent, name);
						if (t !is null)
							return t;
					}
					return null;
				}

				// resolves all parents not yet in the tree (recursively)
				void traverseTree(ref InterfaceTree sub)
				{
					availableVariables ~= sub.details.fields.filter!"!a.isPrivate".array;
					foreach (i, parent; sub.details.parentPositions)
					{
						string parentName = sub.details.normalizedParents[i];
						if (treeByName(&tree, parentName) is null)
						{
							auto details = lookupInterface(sub.details.code, parent);
							sub.inherits ~= InterfaceTree(details);
						}
					}
					foreach (ref inherit; sub.inherits)
						traverseTree(inherit);
				}

				traverseTree(tree);

				string changes;
				void processTree(ref InterfaceTree tree)
				{
					auto details = tree.details;
					if (details.methods.length)
					{
						bool first = true;
						foreach (fn; details.methods)
						{
							if (implementedMethods.canFind(fn.identifier))
								continue;
							if (!fn.needsImplementation)
							{
								implementedMethods ~= fn.identifier;
								continue;
							}
							if (first)
							{
								changes ~= "// implement " ~ details.name ~ "\n\n";
								first = false;
							}
							if (details.needsOverride)
								changes ~= "override ";
							changes ~= fn.signature[0 .. $ - 1];
							changes ~= " {";
							if (fn.optionalImplementation)
							{
								changes ~= "\n\t// TODO: optional implementation\n";
							}

							string propertySearch;
							if (fn.signature.canFind("@property") && fn.arguments.length <= 1)
								propertySearch = fn.name;
							else if ((fn.name.startsWith("get") && fn.arguments.length == 0)
								|| (fn.name.startsWith("set") && fn.arguments.length == 1))
								propertySearch = fn.name[3 .. $];

							string foundProperty;
							if (propertySearch)
							{
								foreach (variable; availableVariables)
								{
									if (fieldNameMatches(variable.name, propertySearch))
									{
										foundProperty = variable.name;
										break;
									}
								}
							}

							if (foundProperty.length)
							{
								changes ~= "\n\t";
								if (fn.returnType != "void")
									changes ~= "return ";
								// every way of reaching this branch makes a one-argument method a
								// setter and a zero-argument one a getter; checking the name here
								// would index arguments[0] of e.g. a `@property setting()` getter
								if (fn.arguments.length == 1)
									changes ~= foundProperty ~ " = " ~ fn.arguments[0].name;
								else
									changes ~= foundProperty;
								changes ~= ";\n";
							}
							else if (fn.hasBody)
							{
								changes ~= "\n\t";
								if (fn.returnType != "void")
									changes ~= "return ";
								changes ~= "super." ~ fn.name;
								if (fn.arguments.length)
									changes ~= "(" ~ fn.arguments.map!"a.name".join(", ") ~ ")";
								else if (fn.returnType == "void")
									changes ~= "()"; // make functions that don't return add (), otherwise they might be attributes and don't need that
								changes ~= ";\n";
							}
							else if (fn.returnType != "void")
							{
								changes ~= "\n\t";
								if (fn.isNothrowOrNogc)
								{
									if (fn.returnType.endsWith("[]"))
										changes ~= "return null; // TODO: implement";
									else
										changes ~= "return " ~ fn.returnType ~ ".init; // TODO: implement";
								}
								else
									changes ~= `assert(false, "Method ` ~ fn.name ~ ` not implemented");`;
								changes ~= "\n";
							}
							changes ~= "}\n\n";
						}
					}

					foreach (parent; tree.inherits)
						processTree(parent);
				}

				processTree(tree);

				ret.finish(changes);
			}
			catch (Throwable t)
			{
				ret.error(t);
			}
		});
		return ret;
	}

private:
	LexerConfig config;

	/// Resolves the declaration at `position` through DCD and parses the interface there.
	/// Returns: InterfaceDetails.init if DCD reports no file; reads the file from disk unless it is "stdin".
	InterfaceDetails lookupInterface(string code, int position)
	{
		auto data = get!DCDComponent.findDeclaration(code, position).getBlocking;
		string file = data.file;
		int newPosition = data.position;

		if (!file.length)
			return InterfaceDetails.init;

		string newCode = code;
		if (file != "stdin")
			newCode = readText(file);

		return getInterfaceDetails(file, newCode, newPosition);
	}

	/// Parses `code` and extracts the interface/class details at `position`.
	InterfaceDetails getInterfaceDetails(string file, string code, int position)
	{
		// per-call allocator: this runs on worker threads from `implement`
		RollbackAllocator rba;
		auto tokens = getTokensForParser(cast(ubyte[]) code, config, &workspaced.stringCache);
		auto parsed = parseModule(tokens, file, &rba);
		auto reader = new InterfaceMethodFinder(code, position);
		reader.visit(parsed);
		return reader.details;
	}
}
378 
/// Kind of declarations a code region contains. Each non-`init` value is a distinct bit flag.
enum CodeRegionType : int
{
	/// null region (unset)
	init,
	/// Imports inside the block
	imports = 0x001,
	/// Aliases `alias foo this;`, `alias Type = Other;`
	aliases = 0x002,
	/// Nested classes/structs/unions/etc.
	types = 0x004,
	/// Raw variables `Type name;`
	fields = 0x008,
	/// Normal constructors `this(Args args)`
	ctor = 0x010,
	/// Copy constructors `this(this)`
	copyctor = 0x020,
	/// Destructors `~this()`
	dtor = 0x040,
	/// Properties (functions annotated with `@property`)
	properties = 0x080,
	/// Regular functions
	methods = 0x100,
}
403 
/// Protection attribute of a code region. Each non-`init` value is a distinct bit flag,
/// so several protections can be combined with `|`.
enum CodeRegionProtection : int
{
	/// null protection (unset)
	init,
	/// default (unmarked) protection
	default_ = 0x01,
	/// public protection
	public_ = 0x02,
	/// package (automatic) protection
	package_ = 0x04,
	/// package (manual package name) protection
	packageIdentifier = 0x08,
	/// protected protection
	protected_ = 0x10,
	/// private protection
	private_ = 0x20,
}
422 
/// Whether a code region is static; values are bit flags.
enum CodeRegionStatic : int
{
	/// null static (unset)
	init,
	/// non-static code
	instanced = 0x1,
	/// static code
	static_ = 0x2,
}
433 
/// Represents a class/interface/struct/union/template with body.
struct CodeBlockInfo
{
	/// Kind of the code block; the values index `typePrefixes`.
	enum Type : int
	{
		// keep the underlines in these values for range checking properly

		///
		class_,
		///
		interface_,
		///
		struct_,
		///
		union_,
		///
		template_,
	}

	/// Declaration keyword followed by a space for every `Type`, in enum order.
	static immutable string[] typePrefixes = [
		"class ", "interface ", "struct ", "union ", "template "
	];
	// every Type member needs a matching prefix
	static assert(typePrefixes.length == __traits(allMembers, Type).length);

	///
	Type type;
	///
	string name;
	/// Outer range inside the code spanning curly braces and name but not type keyword.
	uint[2] outerRange;
	/// Inner range of body of the block touching, but not spanning curly braces.
	uint[2] innerRange;

	/// Returns: the keyword introducing this block's `type` plus a trailing space, e.g. `"class "`.
	/// Usable on const and immutable instances.
	string prefix() const @property
	{
		return typePrefixes[cast(int) type];
	}
}
472 
473 private:
474 
/// Prefixes every non-blank line of `code` with `indentation`.
///
/// Line terminators are preserved. Lines consisting only of a line break are left
/// untouched, so blank lines do not receive trailing whitespace. Lines are split with
/// `KeepTerminator.yes`, so a plain `length` check would never be 0; the terminator
/// is chomped before testing.
string indent(string code, string indentation)
{
	return code.lineSplitter!(KeepTerminator.yes)
		.map!(line => line.chomp.length ? indentation ~ line : line)
		.join;
}
481 
/// Checks whether the member variable `field` corresponds to the property name `expected`.
///
/// At most one common member prefix is ignored: a leading `_`, a leading `m_`, or an
/// `m` directly followed by an uppercase ASCII letter (`mFoo`). The remaining names are
/// compared case-insensitively.
bool fieldNameMatches(string field, in char[] expected)
{
	import std.uni : sicmp;

	const strippedPrefix = field.skipOver("_") || field.skipOver("m_");
	if (!strippedPrefix && field.length >= 2 && field[0] == 'm' && field[1].isUpper)
		field = field[1 .. $];

	return sicmp(field, expected) == 0;
}
495 
/// AST visitor that records the innermost class/interface/struct/union/template whose
/// range (from its name up to, excluding, its closing brace) contains `targetPosition`.
/// The result is stored in `block` (left as CodeBlockInfo.init when nothing matched).
final class CodeBlockInfoFinder : ASTVisitor
{
	///
	this(int targetPosition)
	{
		this.targetPosition = targetPosition;
	}

	alias visit = ASTVisitor.visit;

	override void visit(const ClassDeclaration dec)
	{
		visitContainer(dec.name, CodeBlockInfo.Type.class_, dec.structBody);
	}

	override void visit(const InterfaceDeclaration dec)
	{
		visitContainer(dec.name, CodeBlockInfo.Type.interface_, dec.structBody);
	}

	override void visit(const StructDeclaration dec)
	{
		visitContainer(dec.name, CodeBlockInfo.Type.struct_, dec.structBody);
	}

	override void visit(const UnionDeclaration dec)
	{
		visitContainer(dec.name, CodeBlockInfo.Type.union_, dec.structBody);
	}

	override void visit(const TemplateDeclaration dec)
	{
		// templates carry their own brace locations instead of a StructBody
		if (tryEnter(dec.name, CodeBlockInfo.Type.template_, dec.startLocation, dec.endLocation))
			dec.accept(this);
	}

	private void visitContainer(const Token name, CodeBlockInfo.Type type, const StructBody structBody)
	{
		// declarations without a body cannot contain the position
		if (structBody is null)
			return;

		if (tryEnter(name, type, structBody.startLocation, structBody.endLocation))
			structBody.accept(this); // nested blocks overwrite `block` with a tighter match
	}

	/// Records a block when `targetPosition` lies in `[name.index, close)`.
	/// `open` and `close` are the offsets of the opening and closing curly brace.
	/// Returns: whether the block was recorded (and its children should be visited).
	private bool tryEnter(const Token name, CodeBlockInfo.Type type, size_t open, size_t close)
	{
		const bool contained = cast(int) targetPosition >= cast(int) name.index
			&& targetPosition < close;
		if (!contained)
			return false;

		block = CodeBlockInfo(type, name.text,
				[cast(uint) name.index, cast(uint) close + 1],
				[cast(uint) open + 1, cast(uint) close]);
		return true;
	}

	CodeBlockInfo block;
	int targetPosition;
}
557 
// Sample module for the unittests below: class `FooBar` (its name starts at offset 20,
// its opening brace is at offset 27 and its closing brace is the last character) with
// fields, a constructor and methods spread over public/protected/private sections.
// The content is position-sensitive: the unittest asserts exact offsets into it.
version (unittest) static immutable string SimpleClassTestCode = q{
module foo;

class FooBar
{
public:
	int i; // default instanced fields
	string s;
	long l;

	public this() // public instanced ctor
	{
		i = 4;
	}

protected:
	int x; // protected instanced field

private:
	static const int foo() @nogc nothrow pure @system // private static methods
	{
		if (s == "a")
		{
			i = 5;
		}
	}

	static void bar1() {}

	void bar2() {} // private instanced methods
	void bar3() {}
}};
590 
unittest
{
	import std.stdio : stderr;

	auto backend = new WorkspaceD();
	auto workspace = makeTemporaryTestingWorkspace;
	auto instance = backend.addInstance(workspace.directory);
	backend.register!DCDExtComponent;
	DCDExtComponent dcdext = instance.get!DCDExtComponent;

	// offset 123 lies inside the body of class FooBar
	const fooBarBlock = CodeBlockInfo(CodeBlockInfo.Type.class_, "FooBar",
			[20, SimpleClassTestCode.length], [28, SimpleClassTestCode.length - 1]);
	assert(dcdext.getCodeBlockRange(SimpleClassTestCode, 123) == fooBarBlock);

	// offset 19 is the space before `FooBar`, offset 20 its first character
	assert(dcdext.getCodeBlockRange(SimpleClassTestCode, 19) == CodeBlockInfo.init);
	assert(dcdext.getCodeBlockRange(SimpleClassTestCode, 20) != CodeBlockInfo.init);

	auto replacements = dcdext.insertCodeInContainer("void foo()\n{\n\twriteln();\n}",
			SimpleClassTestCode, 123);
	stderr.writeln(replacements);
}