1//
2//   Copyright 2013 Pixar
3//
4//   Licensed under the Apache License, Version 2.0 (the "Apache License")
5//   with the following modification; you may not use this file except in
6//   compliance with the Apache License and the following modification to it:
7//   Section 6. Trademarks. is deleted and replaced with:
8//
9//   6. Trademarks. This License does not grant permission to use the trade
10//      names, trademarks, service marks, or product names of the Licensor
11//      and its affiliates, except as required to comply with Section 4(c) of
12//      the License and to reproduce the content of the NOTICE file.
13//
14//   You may obtain a copy of the Apache License at
15//
16//       http://www.apache.org/licenses/LICENSE-2.0
17//
18//   Unless required by applicable law or agreed to in writing, software
19//   distributed under the Apache License with the above modification is
20//   distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
21//   KIND, either express or implied. See the Apache License for the specific
22//   language governing permissions and limitations under the Apache License.
23//
24
25#include "mtlControlMeshDisplay.h"
26#include <cassert>
27#include "mtlUtils.h"
28
// Metal source for the control-cage display, compiled at runtime by
// createProgram() with -[MTLDevice newLibraryWithSource:options:error:].
//
//  - vs_main: transforms `position` (attribute 0) by
//    DrawData::ModelViewProjectionMatrix, read from buffer index 2 (Draw()
//    uploads 16 floats there), and derives the vertex color from `sharpness`
//    (attribute 1).
//  - fs_main: returns the color it receives from vs_main.
//  - sharpnessToColor: 0 -> green, 2 -> yellow, 4 -> red
//    (red = min(1, s/2), green = min(1, 2 - s/2), blue = 0, alpha = 1).
//
// NOTE(review): Draw() also renders MTLPrimitiveTypePoint with this program,
// but FragmentData has no [[point_size]] output; Metal does not define the
// rasterized point size in that case — confirm and consider writing one.
static const char* s_Shader = R"(
#include <metal_stdlib>
using namespace metal;

float4 sharpnessToColor(float s) {
  //  0.0       2.0       4.0
  // green --- yellow --- red
  return float4(min(1.0, s * 0.5),
                min(1.0, 2.0 - s * 0.5),
                0.0, 1.0);
}

struct DrawData
{
	float4x4 ModelViewProjectionMatrix;
};

struct VertexData
{
	float3 position [[attribute(0)]];
	float sharpness [[attribute(1)]];
};

struct FragmentData
{
	float4 position [[position]];
	float4 color;
};

vertex FragmentData vs_main(VertexData in [[stage_in]],
			   const constant DrawData& drawData [[buffer(2)]]
			   )
{
	FragmentData out;
	out.position = drawData.ModelViewProjectionMatrix * float4(in.position, 1.0);
	out.color = sharpnessToColor(in.sharpness);
	return out;
}

fragment float4 fs_main(FragmentData in [[stage_in]])
{
    return in.color;
}
)";
73
// Creates the control-cage display for `device`.
//
// Edge and vertex display both start disabled and no topology is uploaded
// (edge/point counts are zero) until SetTopology() is called.
// createProgram() builds the render pipeline from a copy of
// `pipelineDescriptor`. NOTE(review): that descriptor is presumably expected to
// already describe the position attribute (attribute 0) — confirm with callers.
MTLControlMeshDisplay::MTLControlMeshDisplay(id<MTLDevice> device, MTLRenderPipelineDescriptor* pipelineDescriptor)
	: _device(device), _displayEdges(false), _displayVertices(false), _numEdges(0), _numPoints(0) {
	const auto result = createProgram(pipelineDescriptor);
	// When NDEBUG compiles the assert away, `result` would otherwise be an
	// unused variable and trigger a warning in release builds.
	(void)result;
	assert(result && "Failed to create program for MTLControlMeshDisplay");
}
79
// Uploads the control-cage topology of `level` into freshly created Metal
// buffers and records the element counts that Draw() uses:
//  - _edgeIndicesBuffer:     two vertex indices per edge (edge e -> slots 2e, 2e+1)
//  - _edgeSharpnessBuffer:   one float sharpness value per edge
//  - _vertexSharpnessBuffer: one float sharpness value per vertex
void MTLControlMeshDisplay::SetTopology(OpenSubdiv::Far::TopologyLevel const &level) {
	using namespace OpenSubdiv;

	const int edgeCount = level.GetNumEdges();
	const int vertexCount = level.GetNumVertices();
	_numEdges = edgeCount;
	_numPoints = vertexCount;

	// Pre-size the arrays and fill them in place.
	std::vector<int> lineIndices(edgeCount * 2);
	std::vector<float> perEdgeSharpness(edgeCount);
	for (int edge = 0; edge < edgeCount; ++edge) {
		const auto endpoints = level.GetEdgeVertices(edge);
		lineIndices[2 * edge]     = endpoints[0];
		lineIndices[2 * edge + 1] = endpoints[1];
		perEdgeSharpness[edge] = level.GetEdgeSharpness(edge);
	}

	std::vector<float> perVertexSharpness(vertexCount);
	for (int vert = 0; vert < vertexCount; ++vert) {
		perVertexSharpness[vert] = level.GetVertexSharpness(vert);
	}

	_edgeIndicesBuffer = Osd::MTLNewBufferFromVector(_device, lineIndices);
	_edgeSharpnessBuffer = Osd::MTLNewBufferFromVector(_device, perEdgeSharpness);
	_vertexSharpnessBuffer = Osd::MTLNewBufferFromVector(_device, perVertexSharpness);
}
109
// Compiles s_Shader and builds _renderPipelineState from a copy of
// `_pipelineDescriptor`, using vs_main/fs_main and adding a vertex-buffer
// layout for the sharpness attribute (attribute 1, buffer index 1).
//
// Returns false, after printing the error description, if either the shader
// library or the render pipeline state cannot be created. Otherwise it
// returns true.
bool MTLControlMeshDisplay::createProgram(MTLRenderPipelineDescriptor* _pipelineDescriptor) {
	const auto options = [MTLCompileOptions new];
	NSError* error = nil;

	const auto library = [_device newLibraryWithSource:@(s_Shader) options:options error:&error];
	if(!library) {
		printf("Failed to create library for MTLControlMeshDisplay\n%s\n", error ? [[error localizedDescription] UTF8String] : "");
		return false;
	}

	const auto vertexFunction = [library newFunctionWithName:@"vs_main"];
	const auto fragmentFunction = [library newFunctionWithName:@"fs_main"];

	// Configure a copy so the caller's descriptor object is not reassigned.
	MTLRenderPipelineDescriptor* pipelineDescriptor = [_pipelineDescriptor copy];
	pipelineDescriptor.vertexFunction = vertexFunction;
	pipelineDescriptor.fragmentFunction = fragmentFunction;

	// Buffer 1 is bound by Draw() to either _edgeSharpnessBuffer or
	// _vertexSharpnessBuffer. SetTopology() fills both from std::vector<float>,
	// so they are tightly packed single floats. The layout must therefore be
	// one float per vertex, matching `float sharpness [[attribute(1)]]`. A wider
	// stride or format would read past the end of those buffers.
	// NOTE(review): attribute 0 / layout 0 (position) are presumably configured
	// by the caller's vertex descriptor — confirm.
	const auto vertexDescriptor = pipelineDescriptor.vertexDescriptor;
	vertexDescriptor.layouts[1].stride = sizeof(float);
	vertexDescriptor.layouts[1].stepFunction = MTLVertexStepFunctionPerVertex;
	vertexDescriptor.layouts[1].stepRate = 1;
	vertexDescriptor.attributes[1].bufferIndex = 1;
	vertexDescriptor.attributes[1].offset = 0;
	vertexDescriptor.attributes[1].format = MTLVertexFormatFloat;

	_renderPipelineState = [_device newRenderPipelineStateWithDescriptor:pipelineDescriptor error:&error];
	if(!_renderPipelineState) {
		printf("Failed to create render pipeline state for MTLControlMeshDisplay\n%s\n", error ? [[error localizedDescription] UTF8String] : "");
		// Report the failure so the constructor's check can catch it.
		return false;
	}
	return true;
}
141
// Encodes the control-cage draws into `encoder` using _renderPipelineState.
//  - vertexBuffer: bound at vertex buffer index 0. NOTE(review): presumably the
//    control-vertex positions read through attribute 0 — confirm with callers.
//  - modelViewProjectionMatrix: 16 floats, copied inline to buffer index 2
//    (DrawData in the shader).
// When _displayEdges is set, edges are drawn as an indexed line list. When
// _displayVertices is set, vertices are drawn as points. A pass is skipped
// when nothing has been uploaded for it (e.g. SetTopology() has not run),
// because an indexed draw without an index buffer, or a draw without the
// sharpness buffer at index 1, is not a valid Metal draw.
void MTLControlMeshDisplay::Draw(id<MTLRenderCommandEncoder> encoder,
	id<MTLBuffer> vertexBuffer,
	const float *modelViewProjectionMatrix) {
	[encoder setRenderPipelineState: _renderPipelineState];
	[encoder setVertexBuffer:vertexBuffer offset:0 atIndex:0];
	[encoder setVertexBytes:modelViewProjectionMatrix length: sizeof(float) * 16 atIndex:2];

	if(_displayEdges && _numEdges > 0 && _edgeIndicesBuffer != nil && _edgeSharpnessBuffer != nil) {
		// NOTE(review): attribute 1 uses a per-vertex step function, so this
		// per-edge buffer is fetched at each endpoint's *vertex* index, not the
		// edge index. Edge colors may therefore not reflect the drawn edge's
		// sharpness, and vertex indices >= _numEdges read past the buffer.
		// TODO confirm/fix.
		[encoder setVertexBuffer:_edgeSharpnessBuffer offset:0 atIndex:1];
		[encoder drawIndexedPrimitives:MTLPrimitiveTypeLine
		                    indexCount:_numEdges * 2
		                     indexType:MTLIndexTypeUInt32
		                   indexBuffer:_edgeIndicesBuffer
		             indexBufferOffset:0
		                 instanceCount:1
		                    baseVertex:0
		                  baseInstance:0];
	}

	if(_displayVertices && _numPoints > 0 && _vertexSharpnessBuffer != nil) {
		// NOTE(review): the shader writes no [[point_size]], so the rasterized
		// point size is not defined by this program — confirm.
		[encoder setVertexBuffer:_vertexSharpnessBuffer offset:0 atIndex:1];
		[encoder drawPrimitives:MTLPrimitiveTypePoint
		            vertexStart:0
		            vertexCount:_numPoints];
	}
}
168